Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions pydantic_ai_slim/pydantic_ai/ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
import uuid
from collections.abc import AsyncIterator, Callable, Iterable, Mapping, Sequence
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Mapping, Sequence
from dataclasses import Field, dataclass, replace
from http import HTTPStatus
from typing import (
Expand All @@ -17,14 +17,16 @@
Final,
Generic,
Protocol,
TypeAlias,
TypeVar,
runtime_checkable,
)

from pydantic import BaseModel, ValidationError

from . import _utils
from ._agent_graph import CallToolsNode, ModelRequestNode
from .agent import AbstractAgent, AgentRun
from .agent import AbstractAgent, AgentRun, AgentRunResult
from .exceptions import UserError
from .messages import (
FunctionToolResultEvent,
Expand Down Expand Up @@ -107,13 +109,17 @@
'StateDeps',
'StateHandler',
'AGUIApp',
'OnCompleteFunc',
'handle_ag_ui_request',
'run_ag_ui',
]

SSE_CONTENT_TYPE: Final[str] = 'text/event-stream'
"""Content type header value for Server-Sent Events (SSE)."""

OnCompleteFunc: TypeAlias = Callable[[AgentRunResult[Any]], None] | Callable[[AgentRunResult[Any]], Awaitable[None]]
"""Callback function type that receives the `AgentRunResult` of the completed run. Can be sync or async."""


class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette):
"""ASGI application for running Pydantic AI agents with AG-UI protocol support."""
Expand Down Expand Up @@ -220,6 +226,7 @@ async def handle_ag_ui_request(
usage: RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
on_complete: OnCompleteFunc | None = None,
) -> Response:
"""Handle an AG-UI request by running the agent and returning a streaming response.

Expand All @@ -236,6 +243,8 @@ async def handle_ag_ui_request(
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
on_complete: Optional callback function called when the agent run completes successfully.
The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can access `all_messages()` and other result data.

Returns:
A streaming Starlette response with AG-UI protocol events.
Expand Down Expand Up @@ -263,6 +272,7 @@ async def handle_ag_ui_request(
usage=usage,
infer_name=infer_name,
toolsets=toolsets,
on_complete=on_complete,
),
media_type=accept,
)
Expand All @@ -281,6 +291,7 @@ async def run_ag_ui(
usage: RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
on_complete: OnCompleteFunc | None = None,
) -> AsyncIterator[str]:
"""Run the agent with the AG-UI run input and stream AG-UI protocol events.

Expand All @@ -298,6 +309,8 @@ async def run_ag_ui(
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
on_complete: Optional callback function called when the agent run completes successfully.
The callback receives the completed [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] and can access `all_messages()` and other result data.

Yields:
Streaming event chunks encoded as strings according to the accept header value.
Expand Down Expand Up @@ -356,6 +369,12 @@ async def run_ag_ui(
) as run:
async for event in _agent_stream(run):
yield encoder.encode(event)

if on_complete is not None and run.result is not None:
if _utils.is_async_callable(on_complete):
await on_complete(run.result)
else:
await _utils.run_in_executor(on_complete, run.result)
except _RunError as e:
yield encoder.encode(
RunErrorEvent(message=e.message, code=e.code),
Expand Down
97 changes: 94 additions & 3 deletions tests/test_ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pydantic import BaseModel

from pydantic_ai._run_context import RunContext
from pydantic_ai.agent import Agent
from pydantic_ai.agent import Agent, AgentRunResult
from pydantic_ai.exceptions import UserError
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.function import (
Expand Down Expand Up @@ -57,6 +57,7 @@

from pydantic_ai.ag_ui import (
SSE_CONTENT_TYPE,
OnCompleteFunc,
StateDeps,
run_ag_ui,
)
Expand Down Expand Up @@ -96,11 +97,14 @@ def simple_result() -> Any:


async def run_and_collect_events(
agent: Agent[AgentDepsT, OutputDataT], *run_inputs: RunAgentInput, deps: AgentDepsT = None
agent: Agent[AgentDepsT, OutputDataT],
*run_inputs: RunAgentInput,
deps: AgentDepsT = None,
on_complete: OnCompleteFunc | None = None,
) -> list[dict[str, Any]]:
events = list[dict[str, Any]]()
for run_input in run_inputs:
async for event in run_ag_ui(agent, run_input, deps=deps):
async for event in run_ag_ui(agent, run_input, deps=deps, on_complete=on_complete):
events.append(json.loads(event.removeprefix('data: ')))
return events

Expand Down Expand Up @@ -1256,3 +1260,90 @@ async def test_to_ag_ui() -> None:
events.append(json.loads(line.removeprefix('data: ')))

assert events == simple_result()


async def test_callback_sync() -> None:
"""Test that sync callbacks work correctly."""

captured_results: list[AgentRunResult[Any]] = []

def sync_callback(run_result: AgentRunResult[Any]) -> None:
captured_results.append(run_result)

agent = Agent(TestModel())
run_input = create_input(
UserMessage(
id='msg1',
content='Hello!',
)
)

events = await run_and_collect_events(agent, run_input, on_complete=sync_callback)

# Verify callback was called
assert len(captured_results) == 1
run_result = captured_results[0]

# Verify we can access messages
messages = run_result.all_messages()
assert len(messages) >= 1

# Verify events were still streamed normally
assert len(events) > 0
assert events[0]['type'] == 'RUN_STARTED'
assert events[-1]['type'] == 'RUN_FINISHED'


async def test_callback_async() -> None:
"""Test that async callbacks work correctly."""

captured_results: list[AgentRunResult[Any]] = []

async def async_callback(run_result: AgentRunResult[Any]) -> None:
captured_results.append(run_result)

agent = Agent(TestModel())
run_input = create_input(
UserMessage(
id='msg1',
content='Hello!',
)
)

events = await run_and_collect_events(agent, run_input, on_complete=async_callback)

# Verify callback was called
assert len(captured_results) == 1
run_result = captured_results[0]

# Verify we can access messages
messages = run_result.all_messages()
assert len(messages) >= 1

# Verify events were still streamed normally
assert len(events) > 0
assert events[0]['type'] == 'RUN_STARTED'
assert events[-1]['type'] == 'RUN_FINISHED'


async def test_callback_with_error() -> None:
"""Test that callbacks are not called when errors occur."""

captured_results: list[AgentRunResult[Any]] = []

def error_callback(run_result: AgentRunResult[Any]) -> None:
captured_results.append(run_result) # pragma: no cover

agent = Agent(TestModel())
# Empty messages should cause an error
run_input = create_input() # No messages will cause _NoMessagesError

events = await run_and_collect_events(agent, run_input, on_complete=error_callback)

# Verify callback was not called due to error
assert len(captured_results) == 0

# Verify error event was sent
assert len(events) > 0
assert events[0]['type'] == 'RUN_STARTED'
assert any(event['type'] == 'RUN_ERROR' for event in events)