Skip to content

Commit d78a3db

Browse files
ChuckJonasDouweM
andauthored
Add on_complete callback to AG-UI functions to get access to AgentRunResult (#2429)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 95b80fa commit d78a3db

File tree

2 files changed

+115
-5
lines changed

2 files changed

+115
-5
lines changed

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import json
1010
import uuid
11-
from collections.abc import AsyncIterator, Callable, Iterable, Mapping, Sequence
11+
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Mapping, Sequence
1212
from dataclasses import Field, dataclass, replace
1313
from http import HTTPStatus
1414
from typing import (
@@ -17,14 +17,16 @@
1717
Final,
1818
Generic,
1919
Protocol,
20+
TypeAlias,
2021
TypeVar,
2122
runtime_checkable,
2223
)
2324

2425
from pydantic import BaseModel, ValidationError
2526

27+
from . import _utils
2628
from ._agent_graph import CallToolsNode, ModelRequestNode
27-
from .agent import AbstractAgent, AgentRun
29+
from .agent import AbstractAgent, AgentRun, AgentRunResult
2830
from .exceptions import UserError
2931
from .messages import (
3032
FunctionToolResultEvent,
@@ -107,13 +109,17 @@
107109
'StateDeps',
108110
'StateHandler',
109111
'AGUIApp',
112+
'OnCompleteFunc',
110113
'handle_ag_ui_request',
111114
'run_ag_ui',
112115
]
113116

114117
SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
115118
"""Content type header value for Server-Sent Events (SSE)."""
116119

120+
OnCompleteFunc: TypeAlias = Callable[[AgentRunResult[Any]], None] | Callable[[AgentRunResult[Any]], Awaitable[None]]
121+
"""Callback function type that receives the `AgentRunResult` of the completed run. Can be sync or async."""
122+
117123

118124
class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
119125
"""ASGI application for running Pydantic AI agents with AG-UI protocol support."""
@@ -220,6 +226,7 @@ async def handle_ag_ui_request(
220226
usage: RunUsage | None = None,
221227
infer_name: bool = True,
222228
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
229+
on_complete: OnCompleteFunc | None = None,
223230
) -> Response:
224231
"""Handle an AG-UI request by running the agent and returning a streaming response.
225232
@@ -236,6 +243,8 @@ async def handle_ag_ui_request(
236243
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
237244
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
238245
toolsets: Optional additional toolsets for this run.
246+
on_complete: Optional callback function called when the agent run completes successfully.
247+
The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can access `all_messages()` and other result data.
239248
240249
Returns:
241250
A streaming Starlette response with AG-UI protocol events.
@@ -263,6 +272,7 @@ async def handle_ag_ui_request(
263272
usage=usage,
264273
infer_name=infer_name,
265274
toolsets=toolsets,
275+
on_complete=on_complete,
266276
),
267277
media_type=accept,
268278
)
@@ -281,6 +291,7 @@ async def run_ag_ui(
281291
usage: RunUsage | None = None,
282292
infer_name: bool = True,
283293
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
294+
on_complete: OnCompleteFunc | None = None,
284295
) -> AsyncIterator[str]:
285296
"""Run the agent with the AG-UI run input and stream AG-UI protocol events.
286297
@@ -298,6 +309,8 @@ async def run_ag_ui(
298309
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
299310
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
300311
toolsets: Optional additional toolsets for this run.
312+
on_complete: Optional callback function called when the agent run completes successfully.
313+
The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can access `all_messages()` and other result data.
301314
302315
Yields:
303316
Streaming event chunks encoded as strings according to the accept header value.
@@ -356,6 +369,12 @@ async def run_ag_ui(
356369
) as run:
357370
async for event in _agent_stream(run):
358371
yield encoder.encode(event)
372+
373+
if on_complete is not None and run.result is not None:
374+
if _utils.is_async_callable(on_complete):
375+
await on_complete(run.result)
376+
else:
377+
await _utils.run_in_executor(on_complete, run.result)
359378
except _RunError as e:
360379
yield encoder.encode(
361380
RunErrorEvent(message=e.message, code=e.code),

tests/test_ag_ui.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pydantic import BaseModel
2020

2121
from pydantic_ai._run_context import RunContext
22-
from pydantic_ai.agent import Agent
22+
from pydantic_ai.agent import Agent, AgentRunResult
2323
from pydantic_ai.exceptions import UserError
2424
from pydantic_ai.messages import ModelMessage
2525
from pydantic_ai.models.function import (
@@ -57,6 +57,7 @@
5757

5858
from pydantic_ai.ag_ui import (
5959
SSE_CONTENT_TYPE,
60+
OnCompleteFunc,
6061
StateDeps,
6162
run_ag_ui,
6263
)
@@ -96,11 +97,14 @@ def simple_result() -> Any:
9697

9798

9899
async def run_and_collect_events(
99-
agent: Agent[AgentDepsT, OutputDataT], *run_inputs: RunAgentInput, deps: AgentDepsT = None
100+
agent: Agent[AgentDepsT, OutputDataT],
101+
*run_inputs: RunAgentInput,
102+
deps: AgentDepsT = None,
103+
on_complete: OnCompleteFunc | None = None,
100104
) -> list[dict[str, Any]]:
101105
events = list[dict[str, Any]]()
102106
for run_input in run_inputs:
103-
async for event in run_ag_ui(agent, run_input, deps=deps):
107+
async for event in run_ag_ui(agent, run_input, deps=deps, on_complete=on_complete):
104108
events.append(json.loads(event.removeprefix('data: ')))
105109
return events
106110

@@ -1256,3 +1260,90 @@ async def test_to_ag_ui() -> None:
12561260
events.append(json.loads(line.removeprefix('data: ')))
12571261

12581262
assert events == simple_result()
1263+
1264+
1265+
async def test_callback_sync() -> None:
1266+
"""Test that sync callbacks work correctly."""
1267+
1268+
captured_results: list[AgentRunResult[Any]] = []
1269+
1270+
def sync_callback(run_result: AgentRunResult[Any]) -> None:
1271+
captured_results.append(run_result)
1272+
1273+
agent = Agent(TestModel())
1274+
run_input = create_input(
1275+
UserMessage(
1276+
id='msg1',
1277+
content='Hello!',
1278+
)
1279+
)
1280+
1281+
events = await run_and_collect_events(agent, run_input, on_complete=sync_callback)
1282+
1283+
# Verify callback was called
1284+
assert len(captured_results) == 1
1285+
run_result = captured_results[0]
1286+
1287+
# Verify we can access messages
1288+
messages = run_result.all_messages()
1289+
assert len(messages) >= 1
1290+
1291+
# Verify events were still streamed normally
1292+
assert len(events) > 0
1293+
assert events[0]['type'] == 'RUN_STARTED'
1294+
assert events[-1]['type'] == 'RUN_FINISHED'
1295+
1296+
1297+
async def test_callback_async() -> None:
1298+
"""Test that async callbacks work correctly."""
1299+
1300+
captured_results: list[AgentRunResult[Any]] = []
1301+
1302+
async def async_callback(run_result: AgentRunResult[Any]) -> None:
1303+
captured_results.append(run_result)
1304+
1305+
agent = Agent(TestModel())
1306+
run_input = create_input(
1307+
UserMessage(
1308+
id='msg1',
1309+
content='Hello!',
1310+
)
1311+
)
1312+
1313+
events = await run_and_collect_events(agent, run_input, on_complete=async_callback)
1314+
1315+
# Verify callback was called
1316+
assert len(captured_results) == 1
1317+
run_result = captured_results[0]
1318+
1319+
# Verify we can access messages
1320+
messages = run_result.all_messages()
1321+
assert len(messages) >= 1
1322+
1323+
# Verify events were still streamed normally
1324+
assert len(events) > 0
1325+
assert events[0]['type'] == 'RUN_STARTED'
1326+
assert events[-1]['type'] == 'RUN_FINISHED'
1327+
1328+
1329+
async def test_callback_with_error() -> None:
1330+
"""Test that callbacks are not called when errors occur."""
1331+
1332+
captured_results: list[AgentRunResult[Any]] = []
1333+
1334+
def error_callback(run_result: AgentRunResult[Any]) -> None:
1335+
captured_results.append(run_result) # pragma: no cover
1336+
1337+
agent = Agent(TestModel())
1338+
# Empty messages should cause an error
1339+
run_input = create_input() # No messages will cause _NoMessagesError
1340+
1341+
events = await run_and_collect_events(agent, run_input, on_complete=error_callback)
1342+
1343+
# Verify callback was not called due to error
1344+
assert len(captured_results) == 0
1345+
1346+
# Verify error event was sent
1347+
assert len(events) > 0
1348+
assert events[0]['type'] == 'RUN_STARTED'
1349+
assert any(event['type'] == 'RUN_ERROR' for event in events)

0 commit comments

Comments
 (0)