@@ -1252,3 +1252,122 @@ async def test_to_ag_ui() -> None:
1252
1252
events .append (json .loads (line .removeprefix ('data: ' )))
1253
1253
1254
1254
assert events == simple_result ()
1255
+
1256
+
1257
+ async def test_callback_sync () -> None :
1258
+ """Test that sync callbacks work correctly."""
1259
+ from pydantic_ai .agent import AgentRun
1260
+
1261
+ captured_runs : list [AgentRun [Any , Any ]] = []
1262
+
1263
+ def sync_callback (agent_run : AgentRun [Any , Any ]) -> None :
1264
+ captured_runs .append (agent_run )
1265
+
1266
+ agent = Agent (TestModel ())
1267
+ run_input = create_input (
1268
+ UserMessage (
1269
+ id = 'msg1' ,
1270
+ content = 'Hello!' ,
1271
+ )
1272
+ )
1273
+
1274
+ events : list [dict [str , Any ]] = []
1275
+ async for event in run_ag_ui (agent , run_input , on_complete = sync_callback ):
1276
+ events .append (json .loads (event .removeprefix ('data: ' )))
1277
+
1278
+ # Verify callback was called
1279
+ assert len (captured_runs ) == 1
1280
+ agent_run = captured_runs [0 ]
1281
+
1282
+ # Verify we can access messages
1283
+ assert agent_run .result is not None , 'AgentRun result should be available in callback'
1284
+ messages = agent_run .result .all_messages ()
1285
+ assert len (messages ) >= 1
1286
+
1287
+ # Verify events were still streamed normally
1288
+ assert len (events ) > 0
1289
+ assert events [0 ]['type' ] == 'RUN_STARTED'
1290
+ assert events [- 1 ]['type' ] == 'RUN_FINISHED'
1291
+
1292
+
1293
+ async def test_callback_async () -> None :
1294
+ """Test that async callbacks work correctly."""
1295
+ from pydantic_ai .agent import AgentRun
1296
+
1297
+ captured_runs : list [AgentRun [Any , Any ]] = []
1298
+
1299
+ async def async_callback (agent_run : AgentRun [Any , Any ]) -> None :
1300
+ captured_runs .append (agent_run )
1301
+
1302
+ agent = Agent (TestModel ())
1303
+ run_input = create_input (
1304
+ UserMessage (
1305
+ id = 'msg1' ,
1306
+ content = 'Hello!' ,
1307
+ )
1308
+ )
1309
+
1310
+ events : list [dict [str , Any ]] = []
1311
+ async for event in run_ag_ui (agent , run_input , on_complete = async_callback ):
1312
+ events .append (json .loads (event .removeprefix ('data: ' )))
1313
+
1314
+ # Verify callback was called
1315
+ assert len (captured_runs ) == 1
1316
+ agent_run = captured_runs [0 ]
1317
+
1318
+ # Verify we can access messages
1319
+ assert agent_run .result is not None , 'AgentRun result should be available in callback'
1320
+ messages = agent_run .result .all_messages ()
1321
+ assert len (messages ) >= 1
1322
+
1323
+ # Verify events were still streamed normally
1324
+ assert len (events ) > 0
1325
+ assert events [0 ]['type' ] == 'RUN_STARTED'
1326
+ assert events [- 1 ]['type' ] == 'RUN_FINISHED'
1327
+
1328
+
1329
+ async def test_callback_none () -> None :
1330
+ """Test that passing None for callback works (backwards compatibility)."""
1331
+
1332
+ agent = Agent (TestModel ())
1333
+ run_input = create_input (
1334
+ UserMessage (
1335
+ id = 'msg1' ,
1336
+ content = 'Hello!' ,
1337
+ )
1338
+ )
1339
+
1340
+ events : list [dict [str , Any ]] = []
1341
+ async for event in run_ag_ui (agent , run_input , on_complete = None ):
1342
+ events .append (json .loads (event .removeprefix ('data: ' )))
1343
+
1344
+ # Verify events were still streamed normally
1345
+ assert len (events ) > 0
1346
+ assert events [0 ]['type' ] == 'RUN_STARTED'
1347
+ assert events [- 1 ]['type' ] == 'RUN_FINISHED'
1348
+
1349
+
1350
+ async def test_callback_with_error () -> None :
1351
+ """Test that callbacks are not called when errors occur."""
1352
+ from pydantic_ai .agent import AgentRun
1353
+
1354
+ captured_runs : list [AgentRun [Any , Any ]] = []
1355
+
1356
+ def error_callback (agent_run : AgentRun [Any , Any ]) -> None :
1357
+ captured_runs .append (agent_run )
1358
+
1359
+ agent = Agent (TestModel ())
1360
+ # Empty messages should cause an error
1361
+ run_input = create_input () # No messages will cause _NoMessagesError
1362
+
1363
+ events : list [dict [str , Any ]] = []
1364
+ async for event in run_ag_ui (agent , run_input , on_complete = error_callback ):
1365
+ events .append (json .loads (event .removeprefix ('data: ' )))
1366
+
1367
+ # Verify callback was not called due to error
1368
+ assert len (captured_runs ) == 0
1369
+
1370
+ # Verify error event was sent
1371
+ assert len (events ) > 0
1372
+ assert events [0 ]['type' ] == 'RUN_STARTED'
1373
+ assert any (event ['type' ] == 'RUN_ERROR' for event in events )
0 commit comments