@@ -1253,3 +1253,122 @@ async def test_to_ag_ui() -> None:
1253
1253
events .append (json .loads (line .removeprefix ('data: ' )))
1254
1254
1255
1255
assert events == simple_result ()
1256
+
1257
+
1258
+ async def test_callback_sync () -> None :
1259
+ """Test that sync callbacks work correctly."""
1260
+ from pydantic_ai .agent import AgentRun
1261
+
1262
+ captured_runs : list [AgentRun [Any , Any ]] = []
1263
+
1264
+ def sync_callback (agent_run : AgentRun [Any , Any ]) -> None :
1265
+ captured_runs .append (agent_run )
1266
+
1267
+ agent = Agent (TestModel ())
1268
+ run_input = create_input (
1269
+ UserMessage (
1270
+ id = 'msg1' ,
1271
+ content = 'Hello!' ,
1272
+ )
1273
+ )
1274
+
1275
+ events : list [dict [str , Any ]] = []
1276
+ async for event in run_ag_ui (agent , run_input , on_complete = sync_callback ):
1277
+ events .append (json .loads (event .removeprefix ('data: ' )))
1278
+
1279
+ # Verify callback was called
1280
+ assert len (captured_runs ) == 1
1281
+ agent_run = captured_runs [0 ]
1282
+
1283
+ # Verify we can access messages
1284
+ assert agent_run .result is not None , 'AgentRun result should be available in callback'
1285
+ messages = agent_run .result .all_messages ()
1286
+ assert len (messages ) >= 1
1287
+
1288
+ # Verify events were still streamed normally
1289
+ assert len (events ) > 0
1290
+ assert events [0 ]['type' ] == 'RUN_STARTED'
1291
+ assert events [- 1 ]['type' ] == 'RUN_FINISHED'
1292
+
1293
+
1294
+ async def test_callback_async () -> None :
1295
+ """Test that async callbacks work correctly."""
1296
+ from pydantic_ai .agent import AgentRun
1297
+
1298
+ captured_runs : list [AgentRun [Any , Any ]] = []
1299
+
1300
+ async def async_callback (agent_run : AgentRun [Any , Any ]) -> None :
1301
+ captured_runs .append (agent_run )
1302
+
1303
+ agent = Agent (TestModel ())
1304
+ run_input = create_input (
1305
+ UserMessage (
1306
+ id = 'msg1' ,
1307
+ content = 'Hello!' ,
1308
+ )
1309
+ )
1310
+
1311
+ events : list [dict [str , Any ]] = []
1312
+ async for event in run_ag_ui (agent , run_input , on_complete = async_callback ):
1313
+ events .append (json .loads (event .removeprefix ('data: ' )))
1314
+
1315
+ # Verify callback was called
1316
+ assert len (captured_runs ) == 1
1317
+ agent_run = captured_runs [0 ]
1318
+
1319
+ # Verify we can access messages
1320
+ assert agent_run .result is not None , 'AgentRun result should be available in callback'
1321
+ messages = agent_run .result .all_messages ()
1322
+ assert len (messages ) >= 1
1323
+
1324
+ # Verify events were still streamed normally
1325
+ assert len (events ) > 0
1326
+ assert events [0 ]['type' ] == 'RUN_STARTED'
1327
+ assert events [- 1 ]['type' ] == 'RUN_FINISHED'
1328
+
1329
+
1330
+ async def test_callback_none () -> None :
1331
+ """Test that passing None for callback works (backwards compatibility)."""
1332
+
1333
+ agent = Agent (TestModel ())
1334
+ run_input = create_input (
1335
+ UserMessage (
1336
+ id = 'msg1' ,
1337
+ content = 'Hello!' ,
1338
+ )
1339
+ )
1340
+
1341
+ events : list [dict [str , Any ]] = []
1342
+ async for event in run_ag_ui (agent , run_input , on_complete = None ):
1343
+ events .append (json .loads (event .removeprefix ('data: ' )))
1344
+
1345
+ # Verify events were still streamed normally
1346
+ assert len (events ) > 0
1347
+ assert events [0 ]['type' ] == 'RUN_STARTED'
1348
+ assert events [- 1 ]['type' ] == 'RUN_FINISHED'
1349
+
1350
+
1351
+ async def test_callback_with_error () -> None :
1352
+ """Test that callbacks are not called when errors occur."""
1353
+ from pydantic_ai .agent import AgentRun
1354
+
1355
+ captured_runs : list [AgentRun [Any , Any ]] = []
1356
+
1357
+ def error_callback (agent_run : AgentRun [Any , Any ]) -> None :
1358
+ captured_runs .append (agent_run )
1359
+
1360
+ agent = Agent (TestModel ())
1361
+ # Empty messages should cause an error
1362
+ run_input = create_input () # No messages will cause _NoMessagesError
1363
+
1364
+ events : list [dict [str , Any ]] = []
1365
+ async for event in run_ag_ui (agent , run_input , on_complete = error_callback ):
1366
+ events .append (json .loads (event .removeprefix ('data: ' )))
1367
+
1368
+ # Verify callback was not called due to error
1369
+ assert len (captured_runs ) == 0
1370
+
1371
+ # Verify error event was sent
1372
+ assert len (events ) > 0
1373
+ assert events [0 ]['type' ] == 'RUN_STARTED'
1374
+ assert any (event ['type' ] == 'RUN_ERROR' for event in events )
0 commit comments