|
19 | 19 | from pydantic import BaseModel
|
20 | 20 |
|
21 | 21 | from pydantic_ai._run_context import RunContext
|
22 |
| -from pydantic_ai.agent import Agent |
| 22 | +from pydantic_ai.agent import Agent, AgentRunResult |
23 | 23 | from pydantic_ai.exceptions import UserError
|
24 | 24 | from pydantic_ai.messages import ModelMessage
|
25 | 25 | from pydantic_ai.models.function import (
|
|
57 | 57 |
|
58 | 58 | from pydantic_ai.ag_ui import (
|
59 | 59 | SSE_CONTENT_TYPE,
|
| 60 | + OnCompleteFunc, |
60 | 61 | StateDeps,
|
61 | 62 | run_ag_ui,
|
62 | 63 | )
|
@@ -96,11 +97,14 @@ def simple_result() -> Any:
|
96 | 97 |
|
97 | 98 |
|
98 | 99 | async def run_and_collect_events(
|
99 |
| - agent: Agent[AgentDepsT, OutputDataT], *run_inputs: RunAgentInput, deps: AgentDepsT = None |
| 100 | + agent: Agent[AgentDepsT, OutputDataT], |
| 101 | + *run_inputs: RunAgentInput, |
| 102 | + deps: AgentDepsT = None, |
| 103 | + on_complete: OnCompleteFunc | None = None, |
100 | 104 | ) -> list[dict[str, Any]]:
|
101 | 105 | events = list[dict[str, Any]]()
|
102 | 106 | for run_input in run_inputs:
|
103 |
| - async for event in run_ag_ui(agent, run_input, deps=deps): |
| 107 | + async for event in run_ag_ui(agent, run_input, deps=deps, on_complete=on_complete): |
104 | 108 | events.append(json.loads(event.removeprefix('data: ')))
|
105 | 109 | return events
|
106 | 110 |
|
@@ -1256,3 +1260,90 @@ async def test_to_ag_ui() -> None:
|
1256 | 1260 | events.append(json.loads(line.removeprefix('data: ')))
|
1257 | 1261 |
|
1258 | 1262 | assert events == simple_result()
|
| 1263 | + |
| 1264 | + |
| 1265 | +async def test_callback_sync() -> None: |
| 1266 | + """Test that sync callbacks work correctly.""" |
| 1267 | + |
| 1268 | + captured_results: list[AgentRunResult[Any]] = [] |
| 1269 | + |
| 1270 | + def sync_callback(run_result: AgentRunResult[Any]) -> None: |
| 1271 | + captured_results.append(run_result) |
| 1272 | + |
| 1273 | + agent = Agent(TestModel()) |
| 1274 | + run_input = create_input( |
| 1275 | + UserMessage( |
| 1276 | + id='msg1', |
| 1277 | + content='Hello!', |
| 1278 | + ) |
| 1279 | + ) |
| 1280 | + |
| 1281 | + events = await run_and_collect_events(agent, run_input, on_complete=sync_callback) |
| 1282 | + |
| 1283 | + # Verify callback was called |
| 1284 | + assert len(captured_results) == 1 |
| 1285 | + run_result = captured_results[0] |
| 1286 | + |
| 1287 | + # Verify we can access messages |
| 1288 | + messages = run_result.all_messages() |
| 1289 | + assert len(messages) >= 1 |
| 1290 | + |
| 1291 | + # Verify events were still streamed normally |
| 1292 | + assert len(events) > 0 |
| 1293 | + assert events[0]['type'] == 'RUN_STARTED' |
| 1294 | + assert events[-1]['type'] == 'RUN_FINISHED' |
| 1295 | + |
| 1296 | + |
| 1297 | +async def test_callback_async() -> None: |
| 1298 | + """Test that async callbacks work correctly.""" |
| 1299 | + |
| 1300 | + captured_results: list[AgentRunResult[Any]] = [] |
| 1301 | + |
| 1302 | + async def async_callback(run_result: AgentRunResult[Any]) -> None: |
| 1303 | + captured_results.append(run_result) |
| 1304 | + |
| 1305 | + agent = Agent(TestModel()) |
| 1306 | + run_input = create_input( |
| 1307 | + UserMessage( |
| 1308 | + id='msg1', |
| 1309 | + content='Hello!', |
| 1310 | + ) |
| 1311 | + ) |
| 1312 | + |
| 1313 | + events = await run_and_collect_events(agent, run_input, on_complete=async_callback) |
| 1314 | + |
| 1315 | + # Verify callback was called |
| 1316 | + assert len(captured_results) == 1 |
| 1317 | + run_result = captured_results[0] |
| 1318 | + |
| 1319 | + # Verify we can access messages |
| 1320 | + messages = run_result.all_messages() |
| 1321 | + assert len(messages) >= 1 |
| 1322 | + |
| 1323 | + # Verify events were still streamed normally |
| 1324 | + assert len(events) > 0 |
| 1325 | + assert events[0]['type'] == 'RUN_STARTED' |
| 1326 | + assert events[-1]['type'] == 'RUN_FINISHED' |
| 1327 | + |
| 1328 | + |
| 1329 | +async def test_callback_with_error() -> None: |
| 1330 | + """Test that callbacks are not called when errors occur.""" |
| 1331 | + |
| 1332 | + captured_results: list[AgentRunResult[Any]] = [] |
| 1333 | + |
| 1334 | + def error_callback(run_result: AgentRunResult[Any]) -> None: |
| 1335 | + captured_results.append(run_result) # pragma: no cover |
| 1336 | + |
| 1337 | + agent = Agent(TestModel()) |
| 1338 | + # Empty messages should cause an error |
| 1339 | + run_input = create_input() # No messages will cause _NoMessagesError |
| 1340 | + |
| 1341 | + events = await run_and_collect_events(agent, run_input, on_complete=error_callback) |
| 1342 | + |
| 1343 | + # Verify callback was not called due to error |
| 1344 | + assert len(captured_results) == 0 |
| 1345 | + |
| 1346 | + # Verify error event was sent |
| 1347 | + assert len(events) > 0 |
| 1348 | + assert events[0]['type'] == 'RUN_STARTED' |
| 1349 | + assert any(event['type'] == 'RUN_ERROR' for event in events) |
0 commit comments