Skip to content

Commit 9dab158

Browse files
committed
change _utils.is_async_callable iscoroutinefunction instead of
1 parent 89c9bb6 commit 9dab158

File tree

2 files changed

+140
-2
lines changed

2 files changed

+140
-2
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
from __future__ import annotations
88

9+
import inspect
910
import json
1011
import uuid
11-
from collections.abc import AsyncIterator, Callable, Iterable, Mapping, Sequence
12+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Mapping, Sequence
1213
from dataclasses import Field, dataclass, replace
1314
from http import HTTPStatus
1415
from typing import (
@@ -18,11 +19,14 @@
1819
Generic,
1920
Protocol,
2021
TypeVar,
22+
Union,
2123
runtime_checkable,
2224
)
2325

2426
from pydantic import BaseModel, ValidationError
2527

28+
from pydantic_ai import _utils
29+
2630
from ._agent_graph import CallToolsNode, ModelRequestNode
2731
from .agent import AbstractAgent, AgentRun
2832
from .exceptions import UserError
@@ -108,13 +112,17 @@
108112
'StateDeps',
109113
'StateHandler',
110114
'AGUIApp',
115+
'AgentRunCallback',
111116
'handle_ag_ui_request',
112117
'run_ag_ui',
113118
]
114119

115120
SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
116121
"""Content type header value for Server-Sent Events (SSE)."""
117122

123+
AgentRunCallback = Callable[[AgentRun[Any, Any]], Union[None, Awaitable[None]]]
124+
"""Callback function type that receives the completed AgentRun. Can be sync or async."""
125+
118126

119127
class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
120128
"""ASGI application for running Pydantic AI agents with AG-UI protocol support."""
@@ -162,7 +170,6 @@ def __init__(
162170
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
163171
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
164172
toolsets: Optional additional toolsets for this run.
165-
166173
debug: Boolean indicating if debug tracebacks should be returned on errors.
167174
routes: A list of routes to serve incoming HTTP and WebSocket requests.
168175
middleware: A list of middleware to run for every request. A starlette application will always
@@ -221,6 +228,7 @@ async def handle_ag_ui_request(
221228
usage: RunUsage | None = None,
222229
infer_name: bool = True,
223230
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
231+
on_complete: AgentRunCallback | None = None,
224232
) -> Response:
225233
"""Handle an AG-UI request by running the agent and returning a streaming response.
226234
@@ -237,6 +245,8 @@ async def handle_ag_ui_request(
237245
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
238246
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
239247
toolsets: Optional additional toolsets for this run.
248+
on_complete: Optional callback function called when the agent run completes successfully.
249+
The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data.
240250
241251
Returns:
242252
A streaming Starlette response with AG-UI protocol events.
@@ -264,6 +274,7 @@ async def handle_ag_ui_request(
264274
usage=usage,
265275
infer_name=infer_name,
266276
toolsets=toolsets,
277+
on_complete=on_complete,
267278
),
268279
media_type=accept,
269280
)
@@ -282,6 +293,7 @@ async def run_ag_ui(
282293
usage: RunUsage | None = None,
283294
infer_name: bool = True,
284295
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
296+
on_complete: AgentRunCallback | None = None,
285297
) -> AsyncIterator[str]:
286298
"""Run the agent with the AG-UI run input and stream AG-UI protocol events.
287299
@@ -299,6 +311,8 @@ async def run_ag_ui(
299311
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
300312
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
301313
toolsets: Optional additional toolsets for this run.
314+
on_complete: Optional callback function called when the agent run completes successfully.
315+
The callback receives the completed [`AgentRun`][pydantic_ai.agent.AgentRun] and can access `all_messages()` and other result data.
302316
303317
Yields:
304318
Streaming event chunks encoded as strings according to the accept header value.
@@ -357,6 +371,11 @@ async def run_ag_ui(
357371
) as run:
358372
async for event in _agent_stream(run):
359373
yield encoder.encode(event)
374+
if on_complete is not None:
375+
if inspect.iscoroutinefunction(on_complete):
376+
await on_complete(run)
377+
else:
378+
await _utils.run_in_executor(on_complete, run)
360379
except _RunError as e:
361380
yield encoder.encode(
362381
RunErrorEvent(message=e.message, code=e.code),

tests/test_ag_ui.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,3 +1252,122 @@ async def test_to_ag_ui() -> None:
12521252
events.append(json.loads(line.removeprefix('data: ')))
12531253

12541254
assert events == simple_result()
1255+
1256+
1257+
async def test_callback_sync() -> None:
1258+
"""Test that sync callbacks work correctly."""
1259+
from pydantic_ai.agent import AgentRun
1260+
1261+
captured_runs: list[AgentRun[Any, Any]] = []
1262+
1263+
def sync_callback(agent_run: AgentRun[Any, Any]) -> None:
1264+
captured_runs.append(agent_run)
1265+
1266+
agent = Agent(TestModel())
1267+
run_input = create_input(
1268+
UserMessage(
1269+
id='msg1',
1270+
content='Hello!',
1271+
)
1272+
)
1273+
1274+
events: list[dict[str, Any]] = []
1275+
async for event in run_ag_ui(agent, run_input, on_complete=sync_callback):
1276+
events.append(json.loads(event.removeprefix('data: ')))
1277+
1278+
# Verify callback was called
1279+
assert len(captured_runs) == 1
1280+
agent_run = captured_runs[0]
1281+
1282+
# Verify we can access messages
1283+
assert agent_run.result is not None, 'AgentRun result should be available in callback'
1284+
messages = agent_run.result.all_messages()
1285+
assert len(messages) >= 1
1286+
1287+
# Verify events were still streamed normally
1288+
assert len(events) > 0
1289+
assert events[0]['type'] == 'RUN_STARTED'
1290+
assert events[-1]['type'] == 'RUN_FINISHED'
1291+
1292+
1293+
async def test_callback_async() -> None:
1294+
"""Test that async callbacks work correctly."""
1295+
from pydantic_ai.agent import AgentRun
1296+
1297+
captured_runs: list[AgentRun[Any, Any]] = []
1298+
1299+
async def async_callback(agent_run: AgentRun[Any, Any]) -> None:
1300+
captured_runs.append(agent_run)
1301+
1302+
agent = Agent(TestModel())
1303+
run_input = create_input(
1304+
UserMessage(
1305+
id='msg1',
1306+
content='Hello!',
1307+
)
1308+
)
1309+
1310+
events: list[dict[str, Any]] = []
1311+
async for event in run_ag_ui(agent, run_input, on_complete=async_callback):
1312+
events.append(json.loads(event.removeprefix('data: ')))
1313+
1314+
# Verify callback was called
1315+
assert len(captured_runs) == 1
1316+
agent_run = captured_runs[0]
1317+
1318+
# Verify we can access messages
1319+
assert agent_run.result is not None, 'AgentRun result should be available in callback'
1320+
messages = agent_run.result.all_messages()
1321+
assert len(messages) >= 1
1322+
1323+
# Verify events were still streamed normally
1324+
assert len(events) > 0
1325+
assert events[0]['type'] == 'RUN_STARTED'
1326+
assert events[-1]['type'] == 'RUN_FINISHED'
1327+
1328+
1329+
async def test_callback_none() -> None:
1330+
"""Test that passing None for callback works (backwards compatibility)."""
1331+
1332+
agent = Agent(TestModel())
1333+
run_input = create_input(
1334+
UserMessage(
1335+
id='msg1',
1336+
content='Hello!',
1337+
)
1338+
)
1339+
1340+
events: list[dict[str, Any]] = []
1341+
async for event in run_ag_ui(agent, run_input, on_complete=None):
1342+
events.append(json.loads(event.removeprefix('data: ')))
1343+
1344+
# Verify events were still streamed normally
1345+
assert len(events) > 0
1346+
assert events[0]['type'] == 'RUN_STARTED'
1347+
assert events[-1]['type'] == 'RUN_FINISHED'
1348+
1349+
1350+
async def test_callback_with_error() -> None:
1351+
"""Test that callbacks are not called when errors occur."""
1352+
from pydantic_ai.agent import AgentRun
1353+
1354+
captured_runs: list[AgentRun[Any, Any]] = []
1355+
1356+
def error_callback(agent_run: AgentRun[Any, Any]) -> None:
1357+
captured_runs.append(agent_run)
1358+
1359+
agent = Agent(TestModel())
1360+
# Empty messages should cause an error
1361+
run_input = create_input() # No messages will cause _NoMessagesError
1362+
1363+
events: list[dict[str, Any]] = []
1364+
async for event in run_ag_ui(agent, run_input, on_complete=error_callback):
1365+
events.append(json.loads(event.removeprefix('data: ')))
1366+
1367+
# Verify callback was not called due to error
1368+
assert len(captured_runs) == 0
1369+
1370+
# Verify error event was sent
1371+
assert len(events) > 0
1372+
assert events[0]['type'] == 'RUN_STARTED'
1373+
assert any(event['type'] == 'RUN_ERROR' for event in events)

0 commit comments

Comments
 (0)