From badcea0d06889de1b1db6ccf36a7682c5d7aedb8 Mon Sep 17 00:00:00 2001 From: charles jonas Date: Tue, 5 Aug 2025 09:16:44 -0600 Subject: [PATCH 1/3] change _utils.is_async_callable iscoroutinefunction instead of --- pydantic_ai_slim/pydantic_ai/ag_ui.py | 22 ++++- tests/test_ag_ui.py | 119 ++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index b768029418..8437010b73 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -6,9 +6,10 @@ from __future__ import annotations +import inspect import json import uuid -from collections.abc import AsyncIterator, Callable, Iterable, Mapping, Sequence +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Mapping, Sequence from dataclasses import Field, dataclass, replace from http import HTTPStatus from typing import ( @@ -23,6 +24,8 @@ from pydantic import BaseModel, ValidationError +from pydantic_ai import _utils + from ._agent_graph import CallToolsNode, ModelRequestNode from .agent import AbstractAgent, AgentRun from .exceptions import UserError @@ -108,6 +111,7 @@ 'StateDeps', 'StateHandler', 'AGUIApp', + 'AgentRunCallback', 'handle_ag_ui_request', 'run_ag_ui', ] @@ -115,6 +119,9 @@ SSE_CONTENT_TYPE: Final[str] = 'text/event-stream' """Content type header value for Server-Sent Events (SSE).""" +AgentRunCallback = Callable[[AgentRun[Any, Any]], None | Awaitable[None]] +"""Callback function type that receives the completed AgentRun. Can be sync or async.""" + class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette): """ASGI application for running Pydantic AI agents with AG-UI protocol support.""" @@ -162,7 +169,6 @@ def __init__( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. - debug: Boolean indicating if debug tracebacks should be returned on errors. routes: A list of routes to serve incoming HTTP and WebSocket requests. middleware: A list of middleware to run for every request. A starlette application will always @@ -221,6 +227,7 @@ async def handle_ag_ui_request( usage: RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + on_complete: AgentRunCallback | None = None, ) -> Response: """Handle an AG-UI request by running the agent and returning a streaming response. @@ -237,6 +244,8 @@ async def handle_ag_ui_request( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + on_complete: Optional callback function called when the agent run completes successfully. + The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data. Returns: A streaming Starlette response with AG-UI protocol events. @@ -264,6 +273,7 @@ async def handle_ag_ui_request( usage=usage, infer_name=infer_name, toolsets=toolsets, + on_complete=on_complete, ), media_type=accept, ) @@ -282,6 +292,7 @@ async def run_ag_ui( usage: RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + on_complete: AgentRunCallback | None = None, ) -> AsyncIterator[str]: """Run the agent with the AG-UI run input and stream AG-UI protocol events. @@ -299,6 +310,8 @@ async def run_ag_ui( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + on_complete: Optional callback function called when the agent run completes successfully. + The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data. Yields: Streaming event chunks encoded as strings according to the accept header value. @@ -357,6 +370,11 @@ async def run_ag_ui( ) as run: async for event in _agent_stream(run): yield encoder.encode(event) + if on_complete is not None: + if inspect.iscoroutinefunction(on_complete): + await on_complete(run) + else: + await _utils.run_in_executor(on_complete, run) except _RunError as e: yield encoder.encode( RunErrorEvent(message=e.message, code=e.code), diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 12f68d42bb..51db541639 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -1252,3 +1252,122 @@ async def test_to_ag_ui() -> None: events.append(json.loads(line.removeprefix('data: '))) assert events == simple_result() + + +async def test_callback_sync() -> None: + """Test that sync callbacks work correctly.""" + from pydantic_ai.agent import AgentRun + + captured_runs: list[AgentRun[Any, Any]] = [] + + def sync_callback(agent_run: AgentRun[Any, Any]) -> None: + captured_runs.append(agent_run) + + agent = Agent(TestModel()) + run_input = create_input( + UserMessage( + id='msg1', + content='Hello!', + ) + ) + + events: list[dict[str, Any]] = [] + async for event in run_ag_ui(agent, run_input, on_complete=sync_callback): + events.append(json.loads(event.removeprefix('data: '))) + + # Verify callback was called + assert len(captured_runs) == 1 + agent_run = captured_runs[0] + + # Verify we can access messages + assert agent_run.result is not None, 'AgentRun result should be available in callback' + messages = agent_run.result.all_messages() + assert len(messages) >= 1 + + # Verify events were still streamed normally + assert len(events) > 0 + assert events[0]['type'] == 'RUN_STARTED' + assert events[-1]['type'] == 'RUN_FINISHED' + + +async def test_callback_async() -> None: + """Test that async callbacks work correctly.""" + from pydantic_ai.agent import AgentRun + + captured_runs: list[AgentRun[Any, Any]] = [] + + async def async_callback(agent_run: AgentRun[Any, Any]) -> None: + captured_runs.append(agent_run) + + agent = Agent(TestModel()) + run_input = create_input( + UserMessage( + id='msg1', + content='Hello!', + ) + ) + + events: list[dict[str, Any]] = [] + async for event in run_ag_ui(agent, run_input, on_complete=async_callback): + events.append(json.loads(event.removeprefix('data: '))) + + # Verify callback was called + assert len(captured_runs) == 1 + agent_run = captured_runs[0] + + # Verify we can access messages + assert agent_run.result is not None, 'AgentRun result should be available in callback' + messages = agent_run.result.all_messages() + assert len(messages) >= 1 + + # Verify events were still streamed normally + assert len(events) > 0 + assert events[0]['type'] == 'RUN_STARTED' + assert events[-1]['type'] == 'RUN_FINISHED' + + +async def test_callback_none() -> None: + """Test that passing None for callback works (backwards compatibility).""" + + agent = Agent(TestModel()) + run_input = create_input( + UserMessage( + id='msg1', + content='Hello!', + ) + ) + + events: list[dict[str, Any]] = [] + async for event in run_ag_ui(agent, run_input, on_complete=None): + events.append(json.loads(event.removeprefix('data: '))) + + # Verify events were still streamed normally + assert len(events) > 0 + assert events[0]['type'] == 'RUN_STARTED' + assert events[-1]['type'] == 'RUN_FINISHED' + + +async def test_callback_with_error() -> None: + """Test that callbacks are not called when errors occur.""" + from pydantic_ai.agent import AgentRun + + captured_runs: list[AgentRun[Any, Any]] = [] + + def error_callback(agent_run: AgentRun[Any, Any]) -> None: + captured_runs.append(agent_run) + + agent = Agent(TestModel()) + # Empty messages should cause an error + run_input = create_input() # No messages will cause _NoMessagesError + + events: list[dict[str, Any]] = [] + async for event in run_ag_ui(agent, run_input, on_complete=error_callback): + events.append(json.loads(event.removeprefix('data: '))) + + # Verify callback was not called due to error + assert len(captured_runs) == 0 + + # Verify error event was sent + assert len(events) > 0 + assert events[0]['type'] == 'RUN_STARTED' + assert any(event['type'] == 'RUN_ERROR' for event in events) From d5e8e538140a375c846cfa6dc2a3b20b63c1a160 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 10 Sep 2025 17:46:16 +0000 Subject: [PATCH 2/3] Tweaks --- pydantic_ai_slim/pydantic_ai/ag_ui.py | 33 +++++------ tests/test_ag_ui.py | 80 +++++++++------------------ 2 files changed, 43 insertions(+), 70 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index 3217114494..9f449e9e9a 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -6,7 +6,6 @@ from __future__ import annotations -import inspect import json import uuid from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Mapping, Sequence @@ -18,16 +17,16 @@ Final, Generic, Protocol, + TypeAlias, TypeVar, runtime_checkable, ) from pydantic import BaseModel, ValidationError -from pydantic_ai import _utils - +from . import _utils from ._agent_graph import CallToolsNode, ModelRequestNode -from .agent import AbstractAgent, AgentRun +from .agent import AbstractAgent, AgentRun, AgentRunResult from .exceptions import UserError from .messages import ( FunctionToolResultEvent, @@ -110,7 +109,7 @@ 'StateDeps', 'StateHandler', 'AGUIApp', - 'AgentRunCallback', + 'OnCompleteFunc', 'handle_ag_ui_request', 'run_ag_ui', ] @@ -118,8 +117,8 @@ SSE_CONTENT_TYPE: Final[str] = 'text/event-stream' """Content type header value for Server-Sent Events (SSE).""" -AgentRunCallback = Callable[[AgentRun[Any, Any]], None | Awaitable[None]] -"""Callback function type that receives the completed AgentRun. Can be sync or async.""" +OnCompleteFunc: TypeAlias = Callable[[AgentRunResult[Any]], None] | Callable[[AgentRunResult[Any]], Awaitable[None]] +"""Callback function type that receives the `AgentRunResult` of the completed run. Can be sync or async.""" class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette): @@ -168,6 +167,7 @@ def __init__( usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + debug: Boolean indicating if debug tracebacks should be returned on errors. routes: A list of routes to serve incoming HTTP and WebSocket requests. middleware: A list of middleware to run for every request. A starlette application will always @@ -226,7 +226,7 @@ async def handle_ag_ui_request( usage: RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - on_complete: AgentRunCallback | None = None, + on_complete: OnCompleteFunc | None = None, ) -> Response: """Handle an AG-UI request by running the agent and returning a streaming response. @@ -244,7 +244,7 @@ async def handle_ag_ui_request( infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. on_complete: Optional callback function called when the agent run completes successfully. - The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data. + The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can access `all_messages()` and other result data. Returns: A streaming Starlette response with AG-UI protocol events. @@ -291,7 +291,7 @@ async def run_ag_ui( usage: RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - on_complete: AgentRunCallback | None = None, + on_complete: OnCompleteFunc | None = None, ) -> AsyncIterator[str]: """Run the agent with the AG-UI run input and stream AG-UI protocol events. @@ -310,7 +310,7 @@ async def run_ag_ui( infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. on_complete: Optional callback function called when the agent run completes successfully. - The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data. + The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can access `all_messages()` and other result data. Yields: Streaming event chunks encoded as strings according to the accept header value. @@ -369,11 +369,12 @@ async def run_ag_ui( ) as run: async for event in _agent_stream(run): yield encoder.encode(event) - if on_complete is not None: - if inspect.iscoroutinefunction(on_complete): - await on_complete(run) - else: - await _utils.run_in_executor(on_complete, run) + + if on_complete is not None and run.result is not None: + if _utils.is_async_callable(on_complete): + await on_complete(run.result) + else: + await _utils.run_in_executor(on_complete, run.result) except _RunError as e: yield encoder.encode( RunErrorEvent(message=e.message, code=e.code), diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 47604b42bc..fb3e25bef3 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -19,7 +19,7 @@ from pydantic import BaseModel from pydantic_ai._run_context import RunContext -from pydantic_ai.agent import Agent +from pydantic_ai.agent import Agent, AgentRunResult from pydantic_ai.exceptions import UserError from pydantic_ai.messages import ModelMessage from pydantic_ai.models.function import ( @@ -57,6 +57,7 @@ from pydantic_ai.ag_ui import ( SSE_CONTENT_TYPE, + OnCompleteFunc, StateDeps, run_ag_ui, ) @@ -96,11 +97,14 @@ def simple_result() -> Any: async def run_and_collect_events( - agent: Agent[AgentDepsT, OutputDataT], *run_inputs: RunAgentInput, deps: AgentDepsT = None + agent: Agent[AgentDepsT, OutputDataT], + *run_inputs: RunAgentInput, + deps: AgentDepsT = None, + on_complete: OnCompleteFunc | None = None, ) -> list[dict[str, Any]]: events = list[dict[str, Any]]() for run_input in run_inputs: - async for event in run_ag_ui(agent, run_input, deps=deps): + async for event in run_ag_ui(agent, run_input, deps=deps, on_complete=on_complete): events.append(json.loads(event.removeprefix('data: '))) return events @@ -1260,12 +1264,11 @@ async def test_to_ag_ui() -> None: async def test_callback_sync() -> None: """Test that sync callbacks work correctly.""" - from pydantic_ai.agent import AgentRun - captured_runs: list[AgentRun[Any, Any]] = [] + captured_results: list[AgentRunResult[Any]] = [] - def sync_callback(agent_run: AgentRun[Any, Any]) -> None: - captured_runs.append(agent_run) + def sync_callback(run_result: AgentRunResult[Any]) -> None: + captured_results.append(run_result) agent = Agent(TestModel()) run_input = create_input( @@ -1275,17 +1278,14 @@ def sync_callback(agent_run: AgentRun[Any, Any]) -> None: ) ) - events: list[dict[str, Any]] = [] - async for event in run_ag_ui(agent, run_input, on_complete=sync_callback): - events.append(json.loads(event.removeprefix('data: '))) + events = await run_and_collect_events(agent, run_input, on_complete=sync_callback) # Verify callback was called - assert len(captured_runs) == 1 - agent_run = captured_runs[0] + assert len(captured_results) == 1 + run_result = captured_results[0] # Verify we can access messages - assert agent_run.result is not None, 'AgentRun result should be available in callback' - messages = agent_run.result.all_messages() + messages = run_result.all_messages() assert len(messages) >= 1 # Verify events were still streamed normally @@ -1296,12 +1296,11 @@ def sync_callback(agent_run: AgentRun[Any, Any]) -> None: async def test_callback_async() -> None: """Test that async callbacks work correctly.""" - from pydantic_ai.agent import AgentRun - captured_runs: list[AgentRun[Any, Any]] = [] + captured_results: list[AgentRunResult[Any]] = [] - async def async_callback(agent_run: AgentRun[Any, Any]) -> None: - captured_runs.append(agent_run) + async def async_callback(run_result: AgentRunResult[Any]) -> None: + captured_results.append(run_result) agent = Agent(TestModel()) run_input = create_input( @@ -1311,17 +1310,14 @@ async def async_callback(agent_run: AgentRun[Any, Any]) -> None: ) ) - events: list[dict[str, Any]] = [] - async for event in run_ag_ui(agent, run_input, on_complete=async_callback): - events.append(json.loads(event.removeprefix('data: '))) + events = await run_and_collect_events(agent, run_input, on_complete=async_callback) # Verify callback was called - assert len(captured_runs) == 1 - agent_run = captured_runs[0] + assert len(captured_results) == 1 + run_result = captured_results[0] # Verify we can access messages - assert agent_run.result is not None, 'AgentRun result should be available in callback' - messages = agent_run.result.all_messages() + messages = run_result.all_messages() assert len(messages) >= 1 # Verify events were still streamed normally @@ -1330,46 +1326,22 @@ async def async_callback(agent_run: AgentRun[Any, Any]) -> None: assert events[-1]['type'] == 'RUN_FINISHED' -async def test_callback_none() -> None: - """Test that passing None for callback works (backwards compatibility).""" - - agent = Agent(TestModel()) - run_input = create_input( - UserMessage( - id='msg1', - content='Hello!', - ) - ) - - events: list[dict[str, Any]] = [] - async for event in run_ag_ui(agent, run_input, on_complete=None): - events.append(json.loads(event.removeprefix('data: '))) - - # Verify events were still streamed normally - assert len(events) > 0 - assert events[0]['type'] == 'RUN_STARTED' - assert events[-1]['type'] == 'RUN_FINISHED' - - async def test_callback_with_error() -> None: """Test that callbacks are not called when errors occur.""" - from pydantic_ai.agent import AgentRun - captured_runs: list[AgentRun[Any, Any]] = [] + captured_results: list[AgentRunResult[Any]] = [] - def error_callback(agent_run: AgentRun[Any, Any]) -> None: - captured_runs.append(agent_run) + def error_callback(run_result: AgentRunResult[Any]) -> None: + captured_results.append(run_result) agent = Agent(TestModel()) # Empty messages should cause an error run_input = create_input() # No messages will cause _NoMessagesError - events: list[dict[str, Any]] = [] - async for event in run_ag_ui(agent, run_input, on_complete=error_callback): - events.append(json.loads(event.removeprefix('data: '))) + events = await run_and_collect_events(agent, run_input, on_complete=error_callback) # Verify callback was not called due to error - assert len(captured_runs) == 0 + assert len(captured_results) == 0 # Verify error event was sent assert len(events) > 0 From a921cd6507a929ea79ba414a26b2e416698428fb Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 10 Sep 2025 17:57:12 +0000 Subject: [PATCH 3/3] coverage --- tests/test_ag_ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index fb3e25bef3..66fb6a81b0 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -1332,7 +1332,7 @@ async def test_callback_with_error() -> None: captured_results: list[AgentRunResult[Any]] = [] def error_callback(run_result: AgentRunResult[Any]) -> None: - captured_results.append(run_result) + captured_results.append(run_result) # pragma: no cover agent = Agent(TestModel()) # Empty messages should cause an error