diff --git a/docs/agents.md b/docs/agents.md
index a7502a862f..72a2b89bd4 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -348,11 +348,15 @@ except UnexpectedModelBehavior as e:
"""
messages:
[
- UserPrompt(
- content='Please get me the volume of a box with size 6.',
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='user-prompt',
+ ModelRequest(
+ parts=[
+ UserPromptPart(
+ content='Please get me the volume of a box with size 6.',
+ timestamp=datetime.datetime(...),
+ part_kind='user-prompt',
+ )
+ ],
+ kind='request',
),
ModelResponse(
parts=[
@@ -364,16 +368,19 @@ except UnexpectedModelBehavior as e:
)
],
timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
+ kind='response',
),
- RetryPrompt(
- content='Please try again.',
- tool_name='calc_volume',
- tool_call_id=None,
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='retry-prompt',
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ content='Please try again.',
+ tool_name='calc_volume',
+ tool_call_id=None,
+ timestamp=datetime.datetime(...),
+ part_kind='retry-prompt',
+ )
+ ],
+ kind='request',
),
ModelResponse(
parts=[
@@ -385,8 +392,7 @@ except UnexpectedModelBehavior as e:
)
],
timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
+ kind='response',
),
]
"""
diff --git a/docs/api/messages.md b/docs/api/messages.md
index 9ac5c4ff43..0e06f62d02 100644
--- a/docs/api/messages.md
+++ b/docs/api/messages.md
@@ -1,15 +1,19 @@
# `pydantic_ai.messages`
+The structure of [`ModelMessage`][pydantic_ai.messages.ModelMessage] can be shown as a graph:
+
+```mermaid
+graph RL
+ SystemPromptPart(SystemPromptPart) --- ModelRequestPart
+ UserPromptPart(UserPromptPart) --- ModelRequestPart
+ ToolReturnPart(ToolReturnPart) --- ModelRequestPart
+ RetryPromptPart(RetryPromptPart) --- ModelRequestPart
+ TextPart(TextPart) --- ModelResponsePart
+ ToolCallPart(ToolCallPart) --- ModelResponsePart
+ ModelRequestPart("ModelRequestPart
(Union)") --- ModelRequest
+ ModelRequest("ModelRequest(parts=list[...])") --- ModelMessage
+ ModelResponsePart("ModelResponsePart
(Union)") --- ModelResponse
+ ModelResponse("ModelResponse(parts=list[...])") --- ModelMessage("ModelMessage
(Union)")
+```
+
::: pydantic_ai.messages
- options:
- members:
- - Message
- - SystemPrompt
- - UserPrompt
- - ToolReturn
- - RetryPrompt
- - ModelResponse
- - ToolCall
- - ArgsJson
- - ArgsObject
- - MessagesTypeAdapter
diff --git a/docs/api/models/function.md b/docs/api/models/function.md
index ddeb5dfc9b..6280ae47f1 100644
--- a/docs/api/models/function.md
+++ b/docs/api/models/function.md
@@ -11,21 +11,27 @@ Here's a minimal example:
```py {title="function_model_usage.py" call_name="test_my_agent" lint="not-imports"}
from pydantic_ai import Agent
-from pydantic_ai.messages import Message, ModelResponse
+from pydantic_ai.messages import ModelMessage, ModelResponse
from pydantic_ai.models.function import FunctionModel, AgentInfo
my_agent = Agent('openai:gpt-4o')
-async def model_function(messages: list[Message], info: AgentInfo) -> ModelResponse:
+async def model_function(
+ messages: list[ModelMessage], info: AgentInfo
+) -> ModelResponse:
print(messages)
"""
[
- UserPrompt(
- content='Testing my agent...',
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='user-prompt',
+ ModelRequest(
+ parts=[
+ UserPromptPart(
+ content='Testing my agent...',
+ timestamp=datetime.datetime(...),
+ part_kind='user-prompt',
+ )
+ ],
+ kind='request',
)
]
"""
diff --git a/docs/message-history.md b/docs/message-history.md
index 778d8a99c7..a8e94209d0 100644
--- a/docs/message-history.md
+++ b/docs/message-history.md
@@ -10,8 +10,8 @@ Both [`RunResult`][pydantic_ai.result.RunResult]
(returned by [`Agent.run`][pydantic_ai.Agent.run], [`Agent.run_sync`][pydantic_ai.Agent.run_sync])
and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`Agent.run_stream`][pydantic_ai.Agent.run_stream]) have the following methods:
-* [`all_messages()`][pydantic_ai.result.RunResult.all_messages]: returns all messages, including messages from prior runs and system prompts. There's also a variant that returns JSON bytes, [`all_messages_json()`][pydantic_ai.result.RunResult.all_messages_json].
-* [`new_messages()`][pydantic_ai.result.RunResult.new_messages]: returns only the messages from the current run, excluding system prompts, this is generally the data you want when you want to use the messages in further runs to continue the conversation. There's also a variant that returns JSON bytes, [`new_messages_json()`][pydantic_ai.result.RunResult.new_messages_json].
+* [`all_messages()`][pydantic_ai.result.RunResult.all_messages]: returns all messages, including messages from prior runs. There's also a variant that returns JSON bytes, [`all_messages_json()`][pydantic_ai.result.RunResult.all_messages_json].
+* [`new_messages()`][pydantic_ai.result.RunResult.new_messages]: returns only the messages from the current run. There's also a variant that returns JSON bytes, [`new_messages_json()`][pydantic_ai.result.RunResult.new_messages_json].
!!! info "StreamedRunResult and complete messages"
On [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], the messages returned from these methods will only include the final result message once the stream has finished.
@@ -40,38 +40,18 @@ print(result.data)
print(result.all_messages())
"""
[
- SystemPrompt(
- content='Be a helpful assistant.', role='user', message_kind='system-prompt'
- ),
- UserPrompt(
- content='Tell me a joke.',
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='user-prompt',
- ),
- ModelResponse(
+ ModelRequest(
parts=[
- TextPart(
- content='Did you hear about the toothpaste scandal? They called it Colgate.',
- part_kind='text',
- )
+ SystemPromptPart(
+ content='Be a helpful assistant.', part_kind='system-prompt'
+ ),
+ UserPromptPart(
+ content='Tell me a joke.',
+ timestamp=datetime.datetime(...),
+ part_kind='user-prompt',
+ ),
],
- timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
- ),
-]
-"""
-
-# messages excluding system prompts
-print(result.new_messages())
-"""
-[
- UserPrompt(
- content='Tell me a joke.',
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='user-prompt',
+ kind='request',
),
ModelResponse(
parts=[
@@ -81,8 +61,7 @@ print(result.new_messages())
)
],
timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
+ kind='response',
),
]
"""
@@ -103,17 +82,19 @@ async def main():
print(result.all_messages())
"""
[
- SystemPrompt(
- content='Be a helpful assistant.',
- role='user',
- message_kind='system-prompt',
- ),
- UserPrompt(
- content='Tell me a joke.',
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='user-prompt',
- ),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(
+ content='Be a helpful assistant.', part_kind='system-prompt'
+ ),
+ UserPromptPart(
+ content='Tell me a joke.',
+ timestamp=datetime.datetime(...),
+ part_kind='user-prompt',
+ ),
+ ],
+ kind='request',
+ )
]
"""
@@ -128,16 +109,18 @@ async def main():
print(result.all_messages())
"""
[
- SystemPrompt(
- content='Be a helpful assistant.',
- role='user',
- message_kind='system-prompt',
- ),
- UserPrompt(
- content='Tell me a joke.',
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='user-prompt',
+ ModelRequest(
+ parts=[
+ SystemPromptPart(
+ content='Be a helpful assistant.', part_kind='system-prompt'
+ ),
+ UserPromptPart(
+ content='Tell me a joke.',
+ timestamp=datetime.datetime(...),
+ part_kind='user-prompt',
+ ),
+ ],
+ kind='request',
),
ModelResponse(
parts=[
@@ -147,8 +130,7 @@ async def main():
)
],
timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
+ kind='response',
),
]
"""
@@ -163,16 +145,7 @@ To use existing messages in a run, pass them to the `message_history` parameter
[`Agent.run`][pydantic_ai.Agent.run], [`Agent.run_sync`][pydantic_ai.Agent.run_sync] or
[`Agent.run_stream`][pydantic_ai.Agent.run_stream].
-!!! info "`all_messages()` vs. `new_messages()`"
- PydanticAI will inspect any messages it receives for system prompts.
-
- If any system prompts are found in `message_history`, new system prompts are not generated,
- otherwise new system prompts are generated and inserted before `message_history` in the list of messages
- used in the run.
-
- Thus you can decide whether you want to use system prompts from a previous run or generate them again by using
- [`all_messages()`][pydantic_ai.result.RunResult.all_messages] or [`new_messages()`][pydantic_ai.result.RunResult.new_messages].
-
+If `message_history` is set and not empty, a new system prompt is not generated — we assume the existing message history includes a system prompt.
```python {title="Reusing messages in a conversation" hl_lines="9 13"}
from pydantic_ai import Agent
@@ -190,14 +163,18 @@ print(result2.data)
print(result2.all_messages())
"""
[
- SystemPrompt(
- content='Be a helpful assistant.', role='user', message_kind='system-prompt'
- ),
- UserPrompt(
- content='Tell me a joke.',
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='user-prompt',
+ ModelRequest(
+ parts=[
+ SystemPromptPart(
+ content='Be a helpful assistant.', part_kind='system-prompt'
+ ),
+ UserPromptPart(
+ content='Tell me a joke.',
+ timestamp=datetime.datetime(...),
+ part_kind='user-prompt',
+ ),
+ ],
+ kind='request',
),
ModelResponse(
parts=[
@@ -207,14 +184,17 @@ print(result2.all_messages())
)
],
timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
+ kind='response',
),
- UserPrompt(
- content='Explain?',
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='user-prompt',
+ ModelRequest(
+ parts=[
+ UserPromptPart(
+ content='Explain?',
+ timestamp=datetime.datetime(...),
+ part_kind='user-prompt',
+ )
+ ],
+ kind='request',
),
ModelResponse(
parts=[
@@ -224,8 +204,7 @@ print(result2.all_messages())
)
],
timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
+ kind='response',
),
]
"""
@@ -256,14 +235,18 @@ print(result2.data)
print(result2.all_messages())
"""
[
- SystemPrompt(
- content='Be a helpful assistant.', role='user', message_kind='system-prompt'
- ),
- UserPrompt(
- content='Tell me a joke.',
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='user-prompt',
+ ModelRequest(
+ parts=[
+ SystemPromptPart(
+ content='Be a helpful assistant.', part_kind='system-prompt'
+ ),
+ UserPromptPart(
+ content='Tell me a joke.',
+ timestamp=datetime.datetime(...),
+ part_kind='user-prompt',
+ ),
+ ],
+ kind='request',
),
ModelResponse(
parts=[
@@ -273,14 +256,17 @@ print(result2.all_messages())
)
],
timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
+ kind='response',
),
- UserPrompt(
- content='Explain?',
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='user-prompt',
+ ModelRequest(
+ parts=[
+ UserPromptPart(
+ content='Explain?',
+ timestamp=datetime.datetime(...),
+ part_kind='user-prompt',
+ )
+ ],
+ kind='request',
),
ModelResponse(
parts=[
@@ -290,8 +276,7 @@ print(result2.all_messages())
)
],
timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
+ kind='response',
),
]
"""
diff --git a/docs/testing-evals.md b/docs/testing-evals.md
index a0e593f8b7..b0d23e7f8c 100644
--- a/docs/testing-evals.md
+++ b/docs/testing-evals.md
@@ -100,11 +100,12 @@ from pydantic_ai.models.test import TestModel
from pydantic_ai.messages import (
ArgsDict,
ModelResponse,
- SystemPrompt,
+ SystemPromptPart,
TextPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
+ ModelRequest,
)
from fake_database import DatabaseConn
@@ -125,16 +126,16 @@ async def test_forecast():
assert forecast == '{"weather_forecast":"Sunny with a chance of rain"}' # (5)!
assert weather_agent.last_run_messages == [ # (6)!
- SystemPrompt(
- content='Providing a weather forecast at the locations the user provides.',
- role='user',
- message_kind='system-prompt',
- ),
- UserPrompt(
- content='What will the weather be like in London on 2024-11-28?',
- timestamp=IsNow(tz=timezone.utc), # (7)!
- role='user',
- message_kind='user-prompt',
+ ModelRequest(
+ parts=[
+ SystemPromptPart(
+ content='Providing a weather forecast at the locations the user provides.',
+ ),
+ UserPromptPart(
+ content='What will the weather be like in London on 2024-11-28?',
+ timestamp=IsNow(tz=timezone.utc), # (7)!
+ ),
+ ]
),
ModelResponse(
parts=[
@@ -150,16 +151,16 @@ async def test_forecast():
)
],
timestamp=IsNow(tz=timezone.utc),
- role='model',
- message_kind='model-response',
),
- ToolReturn(
- tool_name='weather_forecast',
- content='Sunny with a chance of rain',
- tool_call_id=None,
- timestamp=IsNow(tz=timezone.utc),
- role='user',
- message_kind='tool-return',
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='weather_forecast',
+ content='Sunny with a chance of rain',
+ tool_call_id=None,
+ timestamp=IsNow(tz=timezone.utc),
+ ),
+ ],
),
ModelResponse(
parts=[
@@ -168,8 +169,6 @@ async def test_forecast():
)
],
timestamp=IsNow(tz=timezone.utc),
- role='model',
- message_kind='model-response',
),
]
```
@@ -198,7 +197,7 @@ import pytest
from pydantic_ai import models
from pydantic_ai.messages import (
- Message,
+ ModelMessage,
ModelResponse,
ToolCallPart,
)
@@ -212,19 +211,19 @@ models.ALLOW_MODEL_REQUESTS = False
def call_weather_forecast( # (1)!
- messages: list[Message], info: AgentInfo
+ messages: list[ModelMessage], info: AgentInfo
) -> ModelResponse:
- if len(messages) == 2:
+ if len(messages) == 1:
# first call, call the weather forecast tool
- user_prompt = messages[1]
+ user_prompt = messages[0].parts[-1]
m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content)
assert m is not None
args = {'location': 'London', 'forecast_date': m.group()} # (2)!
return ModelResponse(parts=[ToolCallPart.from_dict('weather_forecast', args)])
else:
# second call, return the forecast
- msg = messages[-1]
- assert msg.message_kind == 'tool-return'
+ msg = messages[-1].parts[0]
+ assert msg.part_kind == 'tool-return'
return ModelResponse.from_text(f'The forecast is: {msg.content}')
@@ -239,7 +238,7 @@ async def test_forecast_future():
assert forecast == 'The forecast is: Rainy with a chance of sun'
```
-1. We define a function `call_weather_forecast` that will be called by `FunctionModel` in place of the LLM, this function has access to the list of [`Message`][pydantic_ai.messages.Message]s that make up the run, and [`AgentInfo`][pydantic_ai.models.function.AgentInfo] which contains information about the agent and the function tools and return tools.
+1. We define a function `call_weather_forecast` that will be called by `FunctionModel` in place of the LLM, this function has access to the list of [`ModelMessage`][pydantic_ai.messages.ModelMessage]s that make up the run, and [`AgentInfo`][pydantic_ai.models.function.AgentInfo] which contains information about the agent and the function tools and return tools.
2. Our function is slightly intelligent in that it tries to extract a date from the prompt, but just hard codes the location.
3. We use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to replace the agent's model with our custom function.
diff --git a/docs/tools.md b/docs/tools.md
index 0680f0761f..ade96bd181 100644
--- a/docs/tools.md
+++ b/docs/tools.md
@@ -68,16 +68,19 @@ from dice_game import dice_result
print(dice_result.all_messages())
"""
[
- SystemPrompt(
- content="You're a dice game, you should roll the die and see if the number you get back matches the user's guess. If so, tell them they're a winner. Use the player's name in the response.",
- role='user',
- message_kind='system-prompt',
- ),
- UserPrompt(
- content='My guess is 4',
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='user-prompt',
+ ModelRequest(
+ parts=[
+ SystemPromptPart(
+ content="You're a dice game, you should roll the die and see if the number you get back matches the user's guess. If so, tell them they're a winner. Use the player's name in the response.",
+ part_kind='system-prompt',
+ ),
+ UserPromptPart(
+ content='My guess is 4',
+ timestamp=datetime.datetime(...),
+ part_kind='user-prompt',
+ ),
+ ],
+ kind='request',
),
ModelResponse(
parts=[
@@ -89,16 +92,19 @@ print(dice_result.all_messages())
)
],
timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
+ kind='response',
),
- ToolReturn(
- tool_name='roll_die',
- content='4',
- tool_call_id=None,
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='tool-return',
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='roll_die',
+ content='4',
+ tool_call_id=None,
+ timestamp=datetime.datetime(...),
+ part_kind='tool-return',
+ )
+ ],
+ kind='request',
),
ModelResponse(
parts=[
@@ -110,16 +116,19 @@ print(dice_result.all_messages())
)
],
timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
+ kind='response',
),
- ToolReturn(
- tool_name='get_player_name',
- content='Anne',
- tool_call_id=None,
- timestamp=datetime.datetime(...),
- role='user',
- message_kind='tool-return',
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='get_player_name',
+ content='Anne',
+ tool_call_id=None,
+ timestamp=datetime.datetime(...),
+ part_kind='tool-return',
+ )
+ ],
+ kind='request',
),
ModelResponse(
parts=[
@@ -129,8 +138,7 @@ print(dice_result.all_messages())
)
],
timestamp=datetime.datetime(...),
- role='model',
- message_kind='model-response',
+ kind='response',
),
]
"""
@@ -231,7 +239,7 @@ To demonstrate a tool's schema, here we use [`FunctionModel`][pydantic_ai.models
```python {title="tool_schema.py"}
from pydantic_ai import Agent
-from pydantic_ai.messages import Message, ModelResponse
+from pydantic_ai.messages import ModelMessage, ModelResponse
from pydantic_ai.models.function import AgentInfo, FunctionModel
agent = Agent()
@@ -249,7 +257,7 @@ def foobar(a: int, b: str, c: dict[str, list[float]]) -> str:
return f'{a} {b} {c}'
-def print_schema(messages: list[Message], info: AgentInfo) -> ModelResponse:
+def print_schema(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
tool = info.function_tools[0]
print(tool.description)
#> Get me foobar.
diff --git a/pydantic_ai_examples/chat_app.py b/pydantic_ai_examples/chat_app.py
index 0493d46e50..cb5b24c8d2 100644
--- a/pydantic_ai_examples/chat_app.py
+++ b/pydantic_ai_examples/chat_app.py
@@ -29,11 +29,11 @@
from pydantic_ai import Agent
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.messages import (
- Message,
- MessagesTypeAdapter,
+ ModelMessage,
+ ModelMessagesTypeAdapter,
ModelResponse,
TextPart,
- UserPrompt,
+ UserPromptPart,
)
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
@@ -85,8 +85,8 @@ class ChatMessage(TypedDict):
content: str
-def to_chat_message(m: Message) -> ChatMessage:
- if isinstance(m, UserPrompt):
+def to_chat_message(m: ModelMessage) -> ChatMessage:
+ if isinstance(m, UserPromptPart):
return {
'role': 'user',
'timestamp': m.timestamp.isoformat(),
@@ -136,8 +136,8 @@ async def stream_messages():
return StreamingResponse(stream_messages(), media_type='text/plain')
-MessageTypeAdapter: TypeAdapter[Message] = TypeAdapter(
- Annotated[Message, Field(discriminator='message_kind')]
+MessageTypeAdapter: TypeAdapter[ModelMessage] = TypeAdapter(
+ Annotated[ModelMessage, Field(discriminator='message_kind')]
)
P = ParamSpec('P')
R = TypeVar('R')
@@ -190,14 +190,14 @@ async def add_messages(self, messages: bytes):
)
await self._asyncify(self.con.commit)
- async def get_messages(self) -> list[Message]:
+ async def get_messages(self) -> list[ModelMessage]:
c = await self._asyncify(
self._execute, 'SELECT message_list FROM messages order by id desc'
)
rows = await self._asyncify(c.fetchall)
- messages: list[Message] = []
+ messages: list[ModelMessage] = []
for row in rows:
- messages.extend(MessagesTypeAdapter.validate_json(row[0]))
+ messages.extend(ModelMessagesTypeAdapter.validate_json(row[0]))
return messages
def _execute(
diff --git a/pydantic_ai_slim/pydantic_ai/_result.py b/pydantic_ai_slim/pydantic_ai/_result.py
index c35e622d26..2878702c94 100644
--- a/pydantic_ai_slim/pydantic_ai/_result.py
+++ b/pydantic_ai_slim/pydantic_ai/_result.py
@@ -32,7 +32,7 @@ async def validate(
deps: AgentDeps,
retry: int,
tool_call: _messages.ToolCallPart | None,
- messages: list[_messages.Message],
+ messages: list[_messages.ModelMessage],
) -> ResultData:
"""Validate a result but calling the function.
@@ -59,7 +59,7 @@ async def validate(
function = cast(Callable[[Any], ResultData], self.function)
result_data = await _utils.run_in_executor(function, *args)
except ModelRetry as r:
- m = _messages.RetryPrompt(content=r.message)
+ m = _messages.RetryPromptPart(content=r.message)
if tool_call is not None:
m.tool_name = tool_call.tool_name
m.tool_call_id = tool_call.tool_call_id
@@ -71,7 +71,7 @@ async def validate(
class ToolRetryError(Exception):
"""Internal exception used to signal a `ToolRetry` message should be returned to the LLM."""
- def __init__(self, tool_retry: _messages.RetryPrompt):
+ def __init__(self, tool_retry: _messages.RetryPromptPart):
self.tool_retry = tool_retry
super().__init__()
@@ -199,7 +199,7 @@ def validate(
)
except ValidationError as e:
if wrap_validation_errors:
- m = _messages.RetryPrompt(
+ m = _messages.RetryPromptPart(
tool_name=tool_call.tool_name,
content=e.errors(include_url=False),
tool_call_id=tool_call.tool_call_id,
diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py
index 5752880577..6b87a08291 100644
--- a/pydantic_ai_slim/pydantic_ai/_utils.py
+++ b/pydantic_ai_slim/pydantic_ai/_utils.py
@@ -15,7 +15,7 @@
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict
if TYPE_CHECKING:
- from .messages import RetryPrompt, ToolCallPart, ToolReturn
+ from .messages import RetryPromptPart, ToolCallPart, ToolReturnPart
from .tools import ObjectJsonSchema
_P = ParamSpec('_P')
@@ -257,7 +257,7 @@ def now_utc() -> datetime:
return datetime.now(tz=timezone.utc)
-def guard_tool_call_id(t: ToolCallPart | ToolReturn | RetryPrompt, model_source: str) -> str:
+def guard_tool_call_id(t: ToolCallPart | ToolReturnPart | RetryPromptPart, model_source: str) -> str:
"""Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
return t.tool_call_id
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index 83389a5021..9074aed87b 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -81,14 +81,14 @@ class Agent(Generic[AgentDeps, ResultData]):
end_strategy: EndStrategy
"""Strategy for handling tool calls when a final result is found."""
- model_settings: ModelSettings | None = None
+ model_settings: ModelSettings | None
"""Optional model request settings to use for this agents's runs, by default.
Note, if `model_settings` is provided by `run`, `run_sync`, or `run_stream`, those settings will
be merged with this value, with the runtime argument taking priority.
"""
- last_run_messages: list[_messages.Message] | None = None
+ last_run_messages: list[_messages.ModelMessage] | None
"""The messages from the last run, useful when a run raised an exception.
Note: these are not used by the agent, e.g. in future runs, they are just stored for developers' convenience.
@@ -161,6 +161,7 @@ def __init__(
self.end_strategy = end_strategy
self.name = name
self.model_settings = model_settings
+ self.last_run_messages = None
self._result_schema = _result.ResultSchema[result_type].build(
result_type, result_tool_name, result_tool_description
)
@@ -185,7 +186,7 @@ async def run(
self,
user_prompt: str,
*,
- message_history: list[_messages.Message] | None = None,
+ message_history: list[_messages.ModelMessage] | None = None,
model: models.Model | models.KnownModelName | None = None,
deps: AgentDeps = None,
model_settings: ModelSettings | None = None,
@@ -220,6 +221,7 @@ async def run(
model_used, mode_selection = await self._get_model(model)
deps = self._get_deps(deps)
+ new_message_index = len(message_history) if message_history else 0
with _logfire.span(
'{agent_name} run {prompt=}',
@@ -229,8 +231,7 @@ async def run(
model_name=model_used.name(),
agent_name=self.name or 'agent',
) as run_span:
- new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
- self.last_run_messages = messages
+ self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)
for tool in self._function_tools.values():
tool.current_retry = 0
@@ -249,16 +250,16 @@ async def run(
model_response, request_cost = await agent_model.request(messages, model_settings)
model_req_span.set_attribute('response', model_response)
model_req_span.set_attribute('cost', request_cost)
- model_req_span.message = f'model request -> {model_response.message_kind}'
messages.append(model_response)
cost += request_cost
with _logfire.span('handle model response', run_step=run_step) as handle_span:
- final_result, response_messages = await self._handle_model_response(model_response, deps, messages)
+ final_result, tool_responses = await self._handle_model_response(model_response, deps, messages)
- # Add all messages to the conversation
- messages.extend(response_messages)
+ if tool_responses:
+ # Add parts to the conversation as a new message
+ messages.append(_messages.ModelRequest(tool_responses))
# Check if we got a final result
if final_result is not None:
@@ -270,15 +271,15 @@ async def run(
return result.RunResult(messages, new_message_index, result_data, cost)
else:
# continue the conversation
- handle_span.set_attribute('tool_responses', response_messages)
- response_msgs = ' '.join(r.message_kind for r in response_messages)
- handle_span.message = f'handle model response -> {response_msgs}'
+ handle_span.set_attribute('tool_responses', tool_responses)
+ tool_responses = ' '.join(r.part_kind for r in tool_responses)
+ handle_span.message = f'handle model response -> {tool_responses}'
def run_sync(
self,
user_prompt: str,
*,
- message_history: list[_messages.Message] | None = None,
+ message_history: list[_messages.ModelMessage] | None = None,
model: models.Model | models.KnownModelName | None = None,
deps: AgentDeps = None,
model_settings: ModelSettings | None = None,
@@ -330,7 +331,7 @@ async def run_stream(
self,
user_prompt: str,
*,
- message_history: list[_messages.Message] | None = None,
+ message_history: list[_messages.ModelMessage] | None = None,
model: models.Model | models.KnownModelName | None = None,
deps: AgentDeps = None,
model_settings: ModelSettings | None = None,
@@ -368,6 +369,7 @@ async def main():
model_used, mode_selection = await self._get_model(model)
deps = self._get_deps(deps)
+ new_message_index = len(message_history) if message_history else 0
with _logfire.span(
'{agent_name} run stream {prompt=}',
@@ -377,8 +379,7 @@ async def main():
model_name=model_used.name(),
agent_name=self.name or 'agent',
) as run_span:
- new_message_index, messages = await self._prepare_messages(deps, user_prompt, message_history)
- self.last_run_messages = messages
+ self.last_run_messages = messages = await self._prepare_messages(deps, user_prompt, message_history)
for tool in self._function_tools.values():
tool.current_retry = 0
@@ -401,12 +402,12 @@ async def main():
model_req_span.__exit__(None, None, None)
with _logfire.span('handle model response') as handle_span:
- final_result, response_messages = await self._handle_streamed_model_response(
+ final_result, new_messages = await self._handle_streamed_model_response(
model_response, deps, messages
)
# Add all messages to the conversation
- messages.extend(response_messages)
+ messages.extend(new_messages)
# Check if we got a final result
if final_result is not None:
@@ -427,9 +428,12 @@ async def main():
return
else:
# continue the conversation
- handle_span.set_attribute('tool_responses', response_messages)
- response_msgs = ' '.join(r.message_kind for r in response_messages)
- handle_span.message = f'handle model response -> {response_msgs}'
+ # the last new message should always be a model request
+ tool_response_msg = new_messages[-1]
+ assert isinstance(tool_response_msg, _messages.ModelRequest)
+ handle_span.set_attribute('tool_responses', tool_response_msg.parts)
+ tool_responses = ' '.join(r.part_kind for r in tool_response_msg.parts)
+ handle_span.message = f'handle model response -> {tool_responses}'
# the model_response should have been fully streamed by now, we can add it's cost
cost += model_response.cost()
@@ -774,7 +778,7 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
return model_, mode_selection
async def _prepare_model(
- self, model: models.Model, deps: AgentDeps, messages: list[_messages.Message]
+ self, model: models.Model, deps: AgentDeps, messages: list[_messages.ModelMessage]
) -> models.AgentModel:
"""Create building tools and create an agent model."""
function_tools: list[ToolDefinition] = []
@@ -793,28 +797,26 @@ async def add_tool(tool: Tool[AgentDeps]) -> None:
)
async def _prepare_messages(
- self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.Message] | None
- ) -> tuple[int, list[_messages.Message]]:
- # if message history includes system prompts, we don't want to regenerate them
- if message_history and any(isinstance(m, _messages.SystemPrompt) for m in message_history):
+ self, deps: AgentDeps, user_prompt: str, message_history: list[_messages.ModelMessage] | None
+ ) -> list[_messages.ModelMessage]:
+ if message_history:
# shallow copy messages
messages = message_history.copy()
+ messages.append(_messages.ModelRequest([_messages.UserPromptPart(user_prompt)]))
else:
- messages = await self._init_messages(deps)
- if message_history:
- messages += message_history
+ parts = await self._sys_parts(deps)
+ parts.append(_messages.UserPromptPart(user_prompt))
+ messages: list[_messages.ModelMessage] = [_messages.ModelRequest(parts)]
- new_message_index = len(messages)
- messages.append(_messages.UserPrompt(user_prompt))
- return new_message_index, messages
+ return messages
async def _handle_model_response(
- self, model_response: _messages.ModelResponse, deps: AgentDeps, conv_messages: list[_messages.Message]
- ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
+ self, model_response: _messages.ModelResponse, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+ ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
"""Process a non-streamed response from the model.
Returns:
- A tuple of `(final_result, messages)`. If `final_result` is not `None`, the conversation should end.
+ A tuple of `(final_result, request parts)`. If `final_result` is not `None`, the conversation should end.
"""
texts: list[str] = []
tool_calls: list[_messages.ToolCallPart] = []
@@ -833,8 +835,8 @@ async def _handle_model_response(
raise exceptions.UnexpectedModelBehavior('Received empty model response')
async def _handle_text_response(
- self, text: str, deps: AgentDeps, conv_messages: list[_messages.Message]
- ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
+ self, text: str, deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+ ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
"""Handle a plain text response from the model for non-streaming responses."""
if self._allow_text_result:
result_data_input = cast(ResultData, text)
@@ -847,14 +849,14 @@ async def _handle_text_response(
return _MarkFinalResult(result_data), []
else:
self._incr_result_retry()
- response = _messages.RetryPrompt(
+ response = _messages.RetryPromptPart(
content='Plain text responses are not permitted, please call one of the functions instead.',
)
return None, [response]
async def _handle_structured_response(
- self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.Message]
- ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
+ self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+ ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
assert tool_calls, 'Expected at least one tool call'
@@ -863,20 +865,20 @@ async def _handle_structured_response(
# Then process regular tools based on end strategy
if self.end_strategy == 'early' and final_result:
- tool_messages = self._mark_skipped_function_tools(tool_calls)
+ tool_parts = self._mark_skipped_function_tools(tool_calls)
else:
- tool_messages = await self._process_function_tools(tool_calls, deps, conv_messages)
+ tool_parts = await self._process_function_tools(tool_calls, deps, conv_messages)
- return final_result, [*final_messages, *tool_messages]
+ return final_result, [*final_messages, *tool_parts]
async def _process_final_tool_calls(
- self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.Message]
- ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.Message]]:
+ self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+ ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
"""Process any final result tool calls and return the first valid result."""
if not self._result_schema:
return None, []
- messages: list[_messages.Message] = []
+ parts: list[_messages.ModelRequestPart] = []
final_result = None
for call in tool_calls:
@@ -891,11 +893,11 @@ async def _process_final_tool_calls(
result_data = await self._validate_result(result_data, deps, call, conv_messages)
except _result.ToolRetryError as e:
self._incr_result_retry()
- messages.append(e.tool_retry)
+ parts.append(e.tool_retry)
else:
final_result = _MarkFinalResult(result_data)
- messages.append(
- _messages.ToolReturn(
+ parts.append(
+ _messages.ToolReturnPart(
tool_name=call.tool_name,
content='Final result processed.',
tool_call_id=call.tool_call_id,
@@ -903,61 +905,64 @@ async def _process_final_tool_calls(
)
else:
# We already have a final result - mark this one as unused
- messages.append(
- _messages.ToolReturn(
+ parts.append(
+ _messages.ToolReturnPart(
tool_name=call.tool_name,
content='Result tool not used - a final result was already processed.',
tool_call_id=call.tool_call_id,
)
)
- return final_result, messages
+ return final_result, parts
async def _process_function_tools(
- self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.Message]
- ) -> list[_messages.Message]:
+ self, tool_calls: list[_messages.ToolCallPart], deps: AgentDeps, conv_messages: list[_messages.ModelMessage]
+ ) -> list[_messages.ModelRequestPart]:
"""Process function (non-final) tool calls in parallel."""
- messages: list[_messages.Message] = []
- tasks: list[asyncio.Task[_messages.Message]] = []
+ parts: list[_messages.ModelRequestPart] = []
+ tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
for call in tool_calls:
if tool := self._function_tools.get(call.tool_name):
tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
elif self._result_schema is None or call.tool_name not in self._result_schema.tools:
- messages.append(self._unknown_tool(call.tool_name))
+ parts.append(self._unknown_tool(call.tool_name))
# Run all tool tasks in parallel
if tasks:
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
- task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
- messages.extend(task_results)
+ task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+ parts.extend(task_results)
- return messages
+ return parts
def _mark_skipped_function_tools(
self,
tool_calls: list[_messages.ToolCallPart],
- ) -> list[_messages.Message]:
+ ) -> list[_messages.ModelRequestPart]:
"""Mark function tools as skipped when a final result was found with 'early' end strategy."""
- messages: list[_messages.Message] = []
+ parts: list[_messages.ModelRequestPart] = []
for call in tool_calls:
if call.tool_name in self._function_tools:
- messages.append(
- _messages.ToolReturn(
+ parts.append(
+ _messages.ToolReturnPart(
tool_name=call.tool_name,
content='Tool not executed - a final result was already processed.',
tool_call_id=call.tool_call_id,
)
)
elif self._result_schema is None or call.tool_name not in self._result_schema.tools:
- messages.append(self._unknown_tool(call.tool_name))
+ parts.append(self._unknown_tool(call.tool_name))
- return messages
+ return parts
async def _handle_streamed_model_response(
- self, model_response: models.EitherStreamedResponse, deps: AgentDeps, conv_messages: list[_messages.Message]
- ) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.Message]]:
+ self,
+ model_response: models.EitherStreamedResponse,
+ deps: AgentDeps,
+ conv_messages: list[_messages.ModelMessage],
+ ) -> tuple[_MarkFinalResult[models.EitherStreamedResponse] | None, list[_messages.ModelMessage]]:
"""Process a streamed response from the model.
Returns:
@@ -969,14 +974,14 @@ async def _handle_streamed_model_response(
return _MarkFinalResult(model_response), []
else:
self._incr_result_retry()
- response = _messages.RetryPrompt(
+ response = _messages.RetryPromptPart(
content='Plain text responses are not permitted, please call one of the functions instead.',
)
# stream the response, so cost is correct
async for _ in model_response:
pass
- return None, [response]
+ return None, [_messages.ModelRequest([response])]
else:
assert isinstance(model_response, models.StreamStructuredResponse), f'Unexpected response: {model_response}'
if self._result_schema is not None:
@@ -992,34 +997,36 @@ async def _handle_streamed_model_response(
if match := self._result_schema.find_tool(structured_msg):
call, _ = match
- tool_return = _messages.ToolReturn(
+ tool_return = _messages.ToolReturnPart(
tool_name=call.tool_name,
content='Final result processed.',
tool_call_id=call.tool_call_id,
)
- return _MarkFinalResult(model_response), [tool_return]
+ return _MarkFinalResult(model_response), [_messages.ModelRequest([tool_return])]
# the model is calling a tool function, consume the response to get the next message
async for _ in model_response:
pass
- structured_msg = model_response.get()
- if not structured_msg.parts:
+ model_response_msg = model_response.get()
+ if not model_response_msg.parts:
raise exceptions.UnexpectedModelBehavior('Received empty tool call message')
- messages: list[_messages.Message] = [structured_msg]
+ messages: list[_messages.ModelMessage] = [model_response_msg]
# we now run all tool functions in parallel
- tasks: list[asyncio.Task[_messages.Message]] = []
- for item in structured_msg.parts:
+ tasks: list[asyncio.Task[_messages.ModelRequestPart]] = []
+ parts: list[_messages.ModelRequestPart] = []
+ for item in model_response_msg.parts:
if isinstance(item, _messages.ToolCallPart):
call = item
if tool := self._function_tools.get(call.tool_name):
tasks.append(asyncio.create_task(tool.run(deps, call, conv_messages), name=call.tool_name))
else:
- messages.append(self._unknown_tool(call.tool_name))
+ parts.append(self._unknown_tool(call.tool_name))
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
- task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks)
- messages.extend(task_results)
+ task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
+ parts.extend(task_results)
+ messages.append(_messages.ModelRequest(parts))
return None, messages
async def _validate_result(
@@ -1027,7 +1034,7 @@ async def _validate_result(
result_data: ResultData,
deps: AgentDeps,
tool_call: _messages.ToolCallPart | None,
- conv_messages: list[_messages.Message],
+ conv_messages: list[_messages.ModelMessage],
) -> ResultData:
for validator in self._result_validators:
result_data = await validator.validate(
@@ -1042,15 +1049,15 @@ def _incr_result_retry(self) -> None:
f'Exceeded maximum retries ({self._max_result_retries}) for result validation'
)
- async def _init_messages(self, deps: AgentDeps) -> list[_messages.Message]:
+ async def _sys_parts(self, deps: AgentDeps) -> list[_messages.ModelRequestPart]:
"""Build the initial messages for the conversation."""
- messages: list[_messages.Message] = [_messages.SystemPrompt(p) for p in self._system_prompts]
+ messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self._system_prompts]
for sys_prompt_runner in self._system_prompt_functions:
prompt = await sys_prompt_runner.run(deps)
- messages.append(_messages.SystemPrompt(prompt))
+ messages.append(_messages.SystemPromptPart(prompt))
return messages
- def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
+ def _unknown_tool(self, tool_name: str) -> _messages.RetryPromptPart:
self._incr_result_retry()
names = list(self._function_tools.keys())
if self._result_schema:
@@ -1059,7 +1066,7 @@ def _unknown_tool(self, tool_name: str) -> _messages.RetryPrompt:
msg = f'Available tools: {", ".join(names)}'
else:
msg = 'No tools available.'
- return _messages.RetryPrompt(content=f'Unknown tool name: {tool_name!r}. {msg}')
+ return _messages.RetryPromptPart(content=f'Unknown tool name: {tool_name!r}. {msg}')
def _get_deps(self, deps: AgentDeps) -> AgentDeps:
"""Get deps for a run.
diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py
index 5a0605bcb6..481cb4d5eb 100644
--- a/pydantic_ai_slim/pydantic_ai/messages.py
+++ b/pydantic_ai_slim/pydantic_ai/messages.py
@@ -6,12 +6,13 @@
import pydantic
import pydantic_core
+from typing_extensions import Self
from ._utils import now_utc as _now_utc
@dataclass
-class SystemPrompt:
+class SystemPromptPart:
"""A system prompt, generally written by the application developer.
This gives the model context and guidance on how to respond.
@@ -20,15 +21,12 @@ class SystemPrompt:
content: str
"""The content of the prompt."""
- role: Literal['user'] = 'user'
- """Message source identifier, this type is available on all messages as either 'user' or 'model'."""
-
- message_kind: Literal['system-prompt'] = 'system-prompt'
- """Message type identifier, this type is available on all messages as a discriminator."""
+ part_kind: Literal['system-prompt'] = 'system-prompt'
+ """Part type identifier, this is available on all parts as a discriminator."""
@dataclass
-class UserPrompt:
+class UserPromptPart:
"""A user prompt, generally written by the end user.
Content comes from the `user_prompt` parameter of [`Agent.run`][pydantic_ai.Agent.run],
@@ -41,18 +39,15 @@ class UserPrompt:
timestamp: datetime = field(default_factory=_now_utc)
"""The timestamp of the prompt."""
- role: Literal['user'] = 'user'
- """Message source identifier, this type is available on all messages as either 'user' or 'model'."""
-
- message_kind: Literal['user-prompt'] = 'user-prompt'
- """Message type identifier, this type is available on all messages as a discriminator."""
+ part_kind: Literal['user-prompt'] = 'user-prompt'
+ """Part type identifier, this is available on all parts as a discriminator."""
tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))
@dataclass
-class ToolReturn:
+class ToolReturnPart:
"""A tool return message, this encodes the result of running a tool."""
tool_name: str
@@ -67,11 +62,8 @@ class ToolReturn:
timestamp: datetime = field(default_factory=_now_utc)
"""The timestamp, when the tool returned."""
- role: Literal['user'] = 'user'
- """Message source identifier, this type is available on all messages as either 'user' or 'model'."""
-
- message_kind: Literal['tool-return'] = 'tool-return'
- """Message type identifier, this type is available on all messages as a discriminator."""
+ part_kind: Literal['tool-return'] = 'tool-return'
+ """Part type identifier, this is available on all parts as a discriminator."""
def model_response_str(self) -> str:
if isinstance(self.content, str):
@@ -87,11 +79,11 @@ def model_response_object(self) -> dict[str, Any]:
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
-ErrorDetailsTa = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
+error_details_ta = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))
@dataclass
-class RetryPrompt:
+class RetryPromptPart:
"""A message back to a model asking it to try again.
This can be sent for a number of reasons:
@@ -122,21 +114,35 @@ class RetryPrompt:
timestamp: datetime = field(default_factory=_now_utc)
"""The timestamp, when the retry was triggered."""
- role: Literal['user'] = 'user'
- """Message source identifier, this type is available on all messages as either 'user' or 'model'."""
-
- message_kind: Literal['retry-prompt'] = 'retry-prompt'
- """Message type identifier, this type is available on all messages as a discriminator."""
+ part_kind: Literal['retry-prompt'] = 'retry-prompt'
+ """Part type identifier, this is available on all parts as a discriminator."""
def model_response(self) -> str:
if isinstance(self.content, str):
description = self.content
else:
- json_errors = ErrorDetailsTa.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
+ json_errors = error_details_ta.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2)
description = f'{len(self.content)} validation errors: {json_errors.decode()}'
return f'{description}\n\nFix the errors and try again.'
+ModelRequestPart = Annotated[
+ Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind')
+]
+"""A message part sent by PydanticAI to a model."""
+
+
+@dataclass
+class ModelRequest:
+ """A request generated by PydanticAI and sent to a model, e.g. a message from the PydanticAI app to the model."""
+
+ parts: list[ModelRequestPart]
+ """The parts of the user message."""
+
+ kind: Literal['request'] = 'request'
+ """Message type identifier, this is available on all parts as a discriminator."""
+
+
@dataclass
class TextPart:
"""A plain text response from a model."""
@@ -145,7 +151,7 @@ class TextPart:
"""The text content of the response."""
part_kind: Literal['text'] = 'text'
- """Part type identifier, this type is available on all message parts as a discriminator."""
+ """Part type identifier, this is available on all parts as a discriminator."""
@dataclass
@@ -166,7 +172,7 @@ class ArgsDict:
@dataclass
class ToolCallPart:
- """A tool call from the agent."""
+ """A tool call from a model."""
tool_name: str
"""The name of the tool to call."""
@@ -181,14 +187,14 @@ class ToolCallPart:
"""Optional tool call identifier, this is used by some models including OpenAI."""
part_kind: Literal['tool-call'] = 'tool-call'
- """Part type identifier, this type is available on all message parts as a discriminator."""
+ """Part type identifier, this is available on all parts as a discriminator."""
@classmethod
- def from_json(cls, tool_name: str, args_json: str, tool_call_id: str | None = None) -> ToolCallPart:
+ def from_json(cls, tool_name: str, args_json: str, tool_call_id: str | None = None) -> Self:
return cls(tool_name, ArgsJson(args_json), tool_call_id)
@classmethod
- def from_dict(cls, tool_name: str, args_dict: dict[str, Any], tool_call_id: str | None = None) -> ToolCallPart:
+ def from_dict(cls, tool_name: str, args_dict: dict[str, Any], tool_call_id: str | None = None) -> Self:
return cls(tool_name, ArgsDict(args_dict), tool_call_id)
def has_content(self) -> bool:
@@ -199,14 +205,15 @@ def has_content(self) -> bool:
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
+"""A message part returned by a model."""
@dataclass
class ModelResponse:
- """A response from a model."""
+ """A response from a model, e.g. a message from the model to the PydanticAI app."""
parts: list[ModelResponsePart]
- """The parts of the response."""
+ """The parts of the model message."""
timestamp: datetime = field(default_factory=_now_utc)
"""The timestamp of the response.
@@ -214,25 +221,22 @@ class ModelResponse:
If the model provides a timestamp in the response (as OpenAI does) that will be used.
"""
- role: Literal['model'] = 'model'
- """Message source identifier, this type is available on all messages as either 'user' or 'model'."""
-
- message_kind: Literal['model-response'] = 'model-response'
- """Message source identifier, this type is available on all messages as a discriminator."""
+ kind: Literal['response'] = 'response'
+ """Message type identifier, this is available on all parts as a discriminator."""
@classmethod
- def from_text(cls, content: str, timestamp: datetime | None = None) -> ModelResponse:
+ def from_text(cls, content: str, timestamp: datetime | None = None) -> Self:
return cls([TextPart(content)], timestamp=timestamp or _now_utc())
@classmethod
- def from_tool_call(cls, tool_call: ToolCallPart) -> ModelResponse:
+ def from_tool_call(cls, tool_call: ToolCallPart) -> Self:
return cls([tool_call])
-Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, ModelResponse]
+ModelMessage = Union[ModelRequest, ModelResponse]
"""Any message send to or returned by a model."""
-MessagesTypeAdapter = pydantic.TypeAdapter(
- list[Annotated[Message, pydantic.Discriminator('message_kind')]], config=pydantic.ConfigDict(defer_build=True)
+ModelMessagesTypeAdapter = pydantic.TypeAdapter(
+ list[Annotated[ModelMessage, pydantic.Discriminator('kind')]], config=pydantic.ConfigDict(defer_build=True)
)
"""Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 1bc0a9db81..0291e6bcd7 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -16,7 +16,7 @@
import httpx
from ..exceptions import UserError
-from ..messages import Message, ModelResponse
+from ..messages import ModelMessage, ModelResponse
from ..settings import ModelSettings
if TYPE_CHECKING:
@@ -121,14 +121,14 @@ class AgentModel(ABC):
@abstractmethod
async def request(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> tuple[ModelResponse, Cost]:
"""Make a request to the model."""
raise NotImplementedError()
@asynccontextmanager
async def request_stream(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> AsyncIterator[EitherStreamedResponse]:
"""Make a request to the model and return a streaming response."""
raise NotImplementedError(f'Streamed requests not supported by this {self.__class__.__name__}')
diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
index 808536fb76..4c4c43841c 100644
--- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -6,21 +6,22 @@
from typing import Any, Literal, Union, cast, overload
from httpx import AsyncClient as AsyncHTTPClient
+from typing_extensions import assert_never
from .. import result
from .._utils import guard_tool_call_id as _guard_tool_call_id
-from ..exceptions import UnexpectedModelBehavior
from ..messages import (
ArgsDict,
- Message,
+ ModelMessage,
+ ModelRequest,
ModelResponse,
ModelResponsePart,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
TextPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
@@ -156,14 +157,14 @@ class AnthropicAgentModel(AgentModel):
tools: list[ToolParam]
async def request(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> tuple[ModelResponse, result.Cost]:
response = await self._messages_create(messages, False, model_settings)
return self._process_response(response), _map_cost(response)
@asynccontextmanager
async def request_stream(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> AsyncIterator[EitherStreamedResponse]:
response = await self._messages_create(messages, True, model_settings)
async with response:
@@ -171,18 +172,18 @@ async def request_stream(
@overload
async def _messages_create(
- self, messages: list[Message], stream: Literal[True], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
) -> AsyncStream[RawMessageStreamEvent]:
pass
@overload
async def _messages_create(
- self, messages: list[Message], stream: Literal[False], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
) -> AnthropicMessage:
pass
async def _messages_create(
- self, messages: list[Message], stream: bool, model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
) -> AnthropicMessage | AsyncStream[RawMessageStreamEvent]:
# standalone function to make it easier to override
if not self.tools:
@@ -192,14 +193,7 @@ async def _messages_create(
else:
tool_choice = {'type': 'auto'}
- system_prompt: str = ''
- anthropic_messages: list[MessageParam] = []
-
- for m in messages:
- if isinstance(m, SystemPrompt):
- system_prompt += m.content
- else:
- anthropic_messages.append(self._map_message(m))
+ system_prompt, anthropic_messages = self._map_message(messages)
model_settings = model_settings or {}
@@ -255,51 +249,60 @@ async def _process_streamed_response(response: AsyncStream[RawMessageStreamEvent
# We might refactor streaming internally before we implement this...
@staticmethod
- def _map_message(message: Message) -> MessageParam:
+ def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]:
"""Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`."""
- if isinstance(message, UserPrompt):
- return MessageParam(role='user', content=message.content)
- elif isinstance(message, ToolReturn):
- return MessageParam(
- role='user',
- content=[
- ToolResultBlockParam(
- tool_use_id=_guard_tool_call_id(t=message, model_source='Anthropic'),
- type='tool_result',
- content=message.model_response_str(),
- is_error=False,
- )
- ],
- )
- elif isinstance(message, RetryPrompt):
- if message.tool_name is None:
- return MessageParam(role='user', content=message.model_response())
+ system_prompt: str = ''
+ anthropic_messages: list[MessageParam] = []
+ for m in messages:
+ if isinstance(m, ModelRequest):
+ for part in m.parts:
+ if isinstance(part, SystemPromptPart):
+ system_prompt += part.content
+ elif isinstance(part, UserPromptPart):
+ anthropic_messages.append(MessageParam(role='user', content=part.content))
+ elif isinstance(part, ToolReturnPart):
+ anthropic_messages.append(
+ MessageParam(
+ role='user',
+ content=[
+ ToolResultBlockParam(
+ tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+ type='tool_result',
+ content=part.model_response_str(),
+ is_error=False,
+ )
+ ],
+ )
+ )
+ elif isinstance(part, RetryPromptPart):
+ if part.tool_name is None:
+ anthropic_messages.append(MessageParam(role='user', content=part.model_response()))
+ else:
+ anthropic_messages.append(
+ MessageParam(
+ role='user',
+ content=[
+ ToolUseBlockParam(
+ id=_guard_tool_call_id(t=part, model_source='Anthropic'),
+ input=part.model_response(),
+ name=part.tool_name,
+ type='tool_use',
+ ),
+ ],
+ )
+ )
+ elif isinstance(m, ModelResponse):
+ content: list[TextBlockParam | ToolUseBlockParam] = []
+ for item in m.parts:
+ if isinstance(item, TextPart):
+ content.append(TextBlockParam(text=item.content, type='text'))
+ else:
+ assert isinstance(item, ToolCallPart)
+ content.append(_map_tool_call(item))
+ anthropic_messages.append(MessageParam(role='assistant', content=content))
else:
- return MessageParam(
- role='user',
- content=[
- ToolUseBlockParam(
- id=_guard_tool_call_id(t=message, model_source='Anthropic'),
- input=message.model_response(),
- name=message.tool_name,
- type='tool_use',
- ),
- ],
- )
- elif isinstance(message, ModelResponse):
- content: list[TextBlockParam | ToolUseBlockParam] = []
- for item in message.parts:
- if isinstance(item, TextPart):
- content.append(TextBlockParam(text=item.content, type='text'))
- else:
- assert isinstance(item, ToolCallPart)
- content.append(_map_tool_call(item))
- return MessageParam(role='assistant', content=content)
- else:
- assert isinstance(message, SystemPrompt)
- raise UnexpectedModelBehavior(
- 'System messages are handled separately for Anthropic, this is a bug, please report it.'
- )
+ assert_never(m)
+ return system_prompt, anthropic_messages
def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py
index e9dfc55c7d..007538f4dd 100644
--- a/pydantic_ai_slim/pydantic_ai/models/function.py
+++ b/pydantic_ai_slim/pydantic_ai/models/function.py
@@ -15,15 +15,16 @@
from .. import _utils, result
from ..messages import (
ArgsJson,
- Message,
+ ModelMessage,
+ ModelRequest,
ModelResponse,
ModelResponsePart,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
TextPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
@@ -120,10 +121,10 @@ class DeltaToolCall:
DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall]
"""A mapping of tool call IDs to incremental changes."""
-FunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
+FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]]
"""A function used to generate a non-streamed response."""
-StreamFunctionDef: TypeAlias = Callable[[list[Message], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
+StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]]
"""A function used to generate a streamed response.
While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should
@@ -142,7 +143,7 @@ class FunctionAgentModel(AgentModel):
agent_info: AgentInfo
async def request(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> tuple[ModelResponse, result.Cost]:
agent_info = replace(self.agent_info, model_settings=model_settings)
@@ -158,7 +159,7 @@ async def request(
@asynccontextmanager
async def request_stream(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> AsyncIterator[EitherStreamedResponse]:
assert (
self.stream_function is not None
@@ -242,35 +243,38 @@ def timestamp(self) -> datetime:
return self._timestamp
-def _estimate_cost(messages: Iterable[Message]) -> result.Cost:
+def _estimate_cost(messages: Iterable[ModelMessage]) -> result.Cost:
"""Very rough guesstimate of the number of tokens associate with a series of messages.
This is designed to be used solely to give plausible numbers for testing!
"""
# there seem to be about 50 tokens of overhead for both Gemini and OpenAI calls, so add that here ¯\_(ツ)_/¯
-
request_tokens = 50
response_tokens = 0
for message in messages:
- if isinstance(message, (SystemPrompt, UserPrompt)):
- request_tokens += _string_cost(message.content)
- elif isinstance(message, ToolReturn):
- request_tokens += _string_cost(message.model_response_str())
- elif isinstance(message, RetryPrompt):
- request_tokens += _string_cost(message.model_response())
+ if isinstance(message, ModelRequest):
+ for part in message.parts:
+ if isinstance(part, (SystemPromptPart, UserPromptPart)):
+ request_tokens += _string_cost(part.content)
+ elif isinstance(part, ToolReturnPart):
+ request_tokens += _string_cost(part.model_response_str())
+ elif isinstance(part, RetryPromptPart):
+ request_tokens += _string_cost(part.model_response())
+ else:
+ assert_never(part)
elif isinstance(message, ModelResponse):
- for item in message.parts:
- if isinstance(item, TextPart):
- response_tokens += _string_cost(item.content)
- elif isinstance(item, ToolCallPart):
- call = item
+ for part in message.parts:
+ if isinstance(part, TextPart):
+ response_tokens += _string_cost(part.content)
+ elif isinstance(part, ToolCallPart):
+ call = part
if isinstance(call.args, ArgsJson):
args_str = call.args.args_json
else:
args_str = pydantic_core.to_json(call.args.args_dict).decode()
response_tokens += 1 + _string_cost(args_str)
else:
- assert_never(item)
+ assert_never(part)
else:
assert_never(message)
return result.Cost(
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index 7705a1b716..7290642800 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -17,15 +17,16 @@
from .. import UnexpectedModelBehavior, _utils, exceptions, result
from ..messages import (
ArgsDict,
- Message,
+ ModelMessage,
+ ModelRequest,
ModelResponse,
ModelResponsePart,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
TextPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from ..settings import ModelSettings
from ..tools import ToolDefinition
@@ -170,7 +171,7 @@ def __init__(
self.url = url
async def request(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> tuple[ModelResponse, result.Cost]:
async with self._make_request(messages, False, model_settings) as http_response:
response = _gemini_response_ta.validate_json(await http_response.aread())
@@ -178,22 +179,16 @@ async def request(
@asynccontextmanager
async def request_stream(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> AsyncIterator[EitherStreamedResponse]:
async with self._make_request(messages, True, model_settings) as http_response:
yield await self._process_streamed_response(http_response)
@asynccontextmanager
async def _make_request(
- self, messages: list[Message], streamed: bool, model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], streamed: bool, model_settings: ModelSettings | None
) -> AsyncIterator[HTTPResponse]:
- sys_prompt_parts: list[_GeminiTextPart] = []
- contents: list[_GeminiContent] = []
- for m in messages:
- if (sys_prompt := self._message_to_gemini_system_prompt(m)) is not None:
- sys_prompt_parts.append(sys_prompt)
- if (content := self._message_to_gemini_content(m)) is not None:
- contents.append(content)
+ sys_prompt_parts, contents = self._message_to_gemini_content(messages)
request_data = _GeminiRequest(contents=contents)
if sys_prompt_parts:
@@ -271,28 +266,31 @@ async def _process_streamed_response(http_response: HTTPResponse) -> EitherStrea
else:
return GeminiStreamTextResponse(_json_content=content, _stream=aiter_bytes)
- @staticmethod
- def _message_to_gemini_system_prompt(m: Message) -> _GeminiTextPart | None:
- """Convert a message to a _GeminiTextPart for "system_instructions" or _GeminiContent for "contents"."""
- if isinstance(m, SystemPrompt):
- return _GeminiTextPart(text=m.content)
- else:
- return None
+ @classmethod
+ def _message_to_gemini_content(
+ cls, messages: list[ModelMessage]
+ ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
+ sys_prompt_parts: list[_GeminiTextPart] = []
+ contents: list[_GeminiContent] = []
+ for m in messages:
+ if isinstance(m, ModelRequest):
+ for part in m.parts:
+ if isinstance(part, SystemPromptPart):
+ sys_prompt_parts.append(_GeminiTextPart(text=part.content))
+ elif isinstance(part, UserPromptPart):
+ contents.append(_content_user_prompt(part))
+ elif isinstance(part, ToolReturnPart):
+ contents.append(_content_tool_return(part))
+ elif isinstance(part, RetryPromptPart):
+ contents.append(_content_retry_prompt(part))
+ else:
+ assert_never(part)
+ elif isinstance(m, ModelResponse):
+ contents.append(_content_model_response(m))
+ else:
+ assert_never(m)
- @staticmethod
- def _message_to_gemini_content(m: Message) -> _GeminiContent | None:
- if isinstance(m, SystemPrompt):
- return None
- elif isinstance(m, UserPrompt):
- return _content_user_prompt(m)
- elif isinstance(m, ToolReturn):
- return _content_tool_return(m)
- elif isinstance(m, RetryPrompt):
- return _content_retry_prompt(m)
- elif isinstance(m, ModelResponse):
- return _content_model_response(m)
- else:
- assert_never(m)
+ return sys_prompt_parts, contents
@dataclass
@@ -423,16 +421,16 @@ class _GeminiContent(TypedDict):
parts: list[_GeminiPartUnion]
-def _content_user_prompt(m: UserPrompt) -> _GeminiContent:
+def _content_user_prompt(m: UserPromptPart) -> _GeminiContent:
return _GeminiContent(role='user', parts=[_GeminiTextPart(text=m.content)])
-def _content_tool_return(m: ToolReturn) -> _GeminiContent:
+def _content_tool_return(m: ToolReturnPart) -> _GeminiContent:
f_response = _response_part_from_response(m.tool_name, m.model_response_object())
return _GeminiContent(role='user', parts=[f_response])
-def _content_retry_prompt(m: RetryPrompt) -> _GeminiContent:
+def _content_retry_prompt(m: RetryPromptPart) -> _GeminiContent:
if m.tool_name is None:
part = _GeminiTextPart(text=m.model_response())
else:
diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py
index 66f3ca9d76..20be8a3390 100644
--- a/pydantic_ai_slim/pydantic_ai/models/groq.py
+++ b/pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -4,6 +4,7 @@
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
+from itertools import chain
from typing import Literal, overload
from httpx import AsyncClient as AsyncHTTPClient
@@ -13,15 +14,16 @@
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
ArgsJson,
- Message,
+ ModelMessage,
+ ModelRequest,
ModelResponse,
ModelResponsePart,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
TextPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from ..result import Cost
from ..settings import ModelSettings
@@ -155,14 +157,14 @@ class GroqAgentModel(AgentModel):
tools: list[chat.ChatCompletionToolParam]
async def request(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> tuple[ModelResponse, result.Cost]:
response = await self._completions_create(messages, False, model_settings)
return self._process_response(response), _map_cost(response)
@asynccontextmanager
async def request_stream(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> AsyncIterator[EitherStreamedResponse]:
response = await self._completions_create(messages, True, model_settings)
async with response:
@@ -170,18 +172,18 @@ async def request_stream(
@overload
async def _completions_create(
- self, messages: list[Message], stream: Literal[True], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
) -> AsyncStream[ChatCompletionChunk]:
pass
@overload
async def _completions_create(
- self, messages: list[Message], stream: Literal[False], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
) -> chat.ChatCompletion:
pass
async def _completions_create(
- self, messages: list[Message], stream: bool, model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
# standalone function to make it easier to override
if not self.tools:
@@ -191,7 +193,7 @@ async def _completions_create(
else:
tool_choice = 'auto'
- groq_messages = [self._map_message(m) for m in messages]
+ groq_messages = list(chain(*(self._map_message(m) for m in messages)))
model_settings = model_settings or {}
@@ -249,28 +251,11 @@ async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk])
start_cost,
)
- @staticmethod
- def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
+ @classmethod
+ def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
"""Just maps a `pydantic_ai.Message` to a `groq.types.ChatCompletionMessageParam`."""
- if isinstance(message, SystemPrompt):
- return chat.ChatCompletionSystemMessageParam(role='system', content=message.content)
- elif isinstance(message, UserPrompt):
- return chat.ChatCompletionUserMessageParam(role='user', content=message.content)
- elif isinstance(message, ToolReturn):
- return chat.ChatCompletionToolMessageParam(
- role='tool',
- tool_call_id=_guard_tool_call_id(t=message, model_source='Groq'),
- content=message.model_response_str(),
- )
- elif isinstance(message, RetryPrompt):
- if message.tool_name is None:
- return chat.ChatCompletionUserMessageParam(role='user', content=message.model_response())
- else:
- return chat.ChatCompletionToolMessageParam(
- role='tool',
- tool_call_id=_guard_tool_call_id(t=message, model_source='Groq'),
- content=message.model_response(),
- )
+ if isinstance(message, ModelRequest):
+ yield from cls._map_user_message(message)
elif isinstance(message, ModelResponse):
texts: list[str] = []
tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -288,10 +273,33 @@ def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
message_param['content'] = '\n\n'.join(texts)
if tool_calls:
message_param['tool_calls'] = tool_calls
- return message_param
+ yield message_param
else:
assert_never(message)
+ @classmethod
+ def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+ for part in message.parts:
+ if isinstance(part, SystemPromptPart):
+ yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+ elif isinstance(part, UserPromptPart):
+ yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+ elif isinstance(part, ToolReturnPart):
+ yield chat.ChatCompletionToolMessageParam(
+ role='tool',
+ tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+ content=part.model_response_str(),
+ )
+ elif isinstance(part, RetryPromptPart):
+ if part.tool_name is None:
+ yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+ else:
+ yield chat.ChatCompletionToolMessageParam(
+ role='tool',
+ tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
+ content=part.model_response(),
+ )
+
@dataclass
class GroqStreamTextResponse(StreamTextResponse):
diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py
index aa843bf9b0..90525869ca 100644
--- a/pydantic_ai_slim/pydantic_ai/models/mistral.py
+++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -5,6 +5,7 @@
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
+from itertools import chain
from typing import Any, Callable, Literal, Union
from httpx import AsyncClient as AsyncHTTPClient, Timeout
@@ -14,15 +15,16 @@
from .._utils import now_utc as _now_utc
from ..messages import (
ArgsJson,
- Message,
+ ModelMessage,
+ ModelRequest,
ModelResponse,
ModelResponsePart,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
TextPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from ..result import Cost
from ..settings import ModelSettings
@@ -153,7 +155,7 @@ class MistralAgentModel(AgentModel):
json_mode_schema_prompt: str = """Answer in JSON Object, respect this following format:\n```\n{schema}\n```\n"""
async def request(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> tuple[ModelResponse, Cost]:
"""Make a non-streaming request to the model from Pydantic AI call."""
response = await self._completions_create(messages, model_settings)
@@ -161,7 +163,7 @@ async def request(
@asynccontextmanager
async def request_stream(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> AsyncIterator[EitherStreamedResponse]:
"""Make a streaming request to the model from Pydantic AI call."""
response = await self._stream_completions_create(messages, model_settings)
@@ -170,13 +172,13 @@ async def request_stream(
yield await self._process_streamed_response(is_function_tool, self.result_tools, response)
async def _completions_create(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> MistralChatCompletionResponse:
"""Make a non-streaming request to the model."""
model_settings = model_settings or {}
response = await self.client.chat.complete_async(
model=str(self.model_name),
- messages=[self._map_message(m) for m in messages],
+ messages=list(chain(*(self._map_message(m) for m in messages))),
n=1,
tools=self._map_function_and_result_tools_definition() or UNSET,
tool_choice=self._get_tool_choice(),
@@ -191,12 +193,12 @@ async def _completions_create(
async def _stream_completions_create(
self,
- messages: list[Message],
+ messages: list[ModelMessage],
model_settings: ModelSettings | None,
) -> MistralEventStreamAsync[MistralCompletionEvent]:
"""Create a streaming completion request to the Mistral model."""
response: MistralEventStreamAsync[MistralCompletionEvent] | None
- mistral_messages = [self._map_message(m) for m in messages]
+ mistral_messages = list(chain(*(self._map_message(m) for m in messages)))
model_settings = model_settings or {}
@@ -361,26 +363,11 @@ async def _process_streamed_response(
start_cost,
)
- @staticmethod
- def _map_message(message: Message) -> MistralMessages:
+ @classmethod
+ def _map_message(cls, message: ModelMessage) -> Iterable[MistralMessages]:
"""Just maps a `pydantic_ai.Message` to a `Mistral Message`."""
- if isinstance(message, SystemPrompt):
- return MistralSystemMessage(content=message.content)
- elif isinstance(message, UserPrompt):
- return MistralUserMessage(content=message.content)
- elif isinstance(message, ToolReturn):
- return MistralToolMessage(
- tool_call_id=message.tool_call_id,
- content=message.model_response_str(),
- )
- elif isinstance(message, RetryPrompt):
- if message.tool_name is None:
- return MistralUserMessage(content=message.model_response())
- else:
- return MistralToolMessage(
- tool_call_id=message.tool_call_id,
- content=message.model_response(),
- )
+ if isinstance(message, ModelRequest):
+ yield from cls._map_user_message(message)
elif isinstance(message, ModelResponse):
content_chunks: list[MistralContentChunk] = []
tool_calls: list[MistralToolCall] = []
@@ -392,10 +379,33 @@ def _map_message(message: Message) -> MistralMessages:
tool_calls.append(_map_pydantic_to_mistral_tool_call(part))
else:
assert_never(part)
- return MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
+ yield MistralAssistantMessage(content=content_chunks, tool_calls=tool_calls)
else:
assert_never(message)
+ @classmethod
+ def _map_user_message(cls, message: ModelRequest) -> Iterable[MistralMessages]:
+ for part in message.parts:
+ if isinstance(part, SystemPromptPart):
+ yield MistralSystemMessage(content=part.content)
+ elif isinstance(part, UserPromptPart):
+ yield MistralUserMessage(content=part.content)
+ elif isinstance(part, ToolReturnPart):
+ yield MistralToolMessage(
+ tool_call_id=part.tool_call_id,
+ content=part.model_response_str(),
+ )
+ elif isinstance(part, RetryPromptPart):
+ if part.tool_name is None:
+ yield MistralUserMessage(content=part.model_response())
+ else:
+ yield MistralToolMessage(
+ tool_call_id=part.tool_call_id,
+ content=part.model_response(),
+ )
+ else:
+ assert_never(part)
+
@dataclass
class MistralStreamTextResponse(StreamTextResponse):
diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index cdb0a1d137..957595987c 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -4,6 +4,7 @@
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
+from itertools import chain
from typing import Literal, Union, overload
from httpx import AsyncClient as AsyncHTTPClient
@@ -13,15 +14,16 @@
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
ArgsJson,
- Message,
+ ModelMessage,
+ ModelRequest,
ModelResponse,
ModelResponsePart,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
TextPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from ..result import Cost
from ..settings import ModelSettings
@@ -144,14 +146,14 @@ class OpenAIAgentModel(AgentModel):
tools: list[chat.ChatCompletionToolParam]
async def request(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> tuple[ModelResponse, result.Cost]:
response = await self._completions_create(messages, False, model_settings)
return self._process_response(response), _map_cost(response)
@asynccontextmanager
async def request_stream(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> AsyncIterator[EitherStreamedResponse]:
response = await self._completions_create(messages, True, model_settings)
async with response:
@@ -159,18 +161,18 @@ async def request_stream(
@overload
async def _completions_create(
- self, messages: list[Message], stream: Literal[True], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: Literal[True], model_settings: ModelSettings | None
) -> AsyncStream[ChatCompletionChunk]:
pass
@overload
async def _completions_create(
- self, messages: list[Message], stream: Literal[False], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: Literal[False], model_settings: ModelSettings | None
) -> chat.ChatCompletion:
pass
async def _completions_create(
- self, messages: list[Message], stream: bool, model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], stream: bool, model_settings: ModelSettings | None
) -> chat.ChatCompletion | AsyncStream[ChatCompletionChunk]:
# standalone function to make it easier to override
if not self.tools:
@@ -180,7 +182,7 @@ async def _completions_create(
else:
tool_choice = 'auto'
- openai_messages = [self._map_message(m) for m in messages]
+ openai_messages = list(chain(*(self._map_message(m) for m in messages)))
model_settings = model_settings or {}
@@ -241,28 +243,11 @@ async def _process_streamed_response(response: AsyncStream[ChatCompletionChunk])
)
# else continue until we get either delta.content or delta.tool_calls
- @staticmethod
- def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
+ @classmethod
+ def _map_message(cls, message: ModelMessage) -> Iterable[chat.ChatCompletionMessageParam]:
"""Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`."""
- if isinstance(message, SystemPrompt):
- return chat.ChatCompletionSystemMessageParam(role='system', content=message.content)
- elif isinstance(message, UserPrompt):
- return chat.ChatCompletionUserMessageParam(role='user', content=message.content)
- elif isinstance(message, ToolReturn):
- return chat.ChatCompletionToolMessageParam(
- role='tool',
- tool_call_id=_guard_tool_call_id(t=message, model_source='OpenAI'),
- content=message.model_response_str(),
- )
- elif isinstance(message, RetryPrompt):
- if message.tool_name is None:
- return chat.ChatCompletionUserMessageParam(role='user', content=message.model_response())
- else:
- return chat.ChatCompletionToolMessageParam(
- role='tool',
- tool_call_id=_guard_tool_call_id(t=message, model_source='OpenAI'),
- content=message.model_response(),
- )
+ if isinstance(message, ModelRequest):
+ yield from cls._map_user_message(message)
elif isinstance(message, ModelResponse):
texts: list[str] = []
tool_calls: list[chat.ChatCompletionMessageToolCallParam] = []
@@ -280,10 +265,35 @@ def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
message_param['content'] = '\n\n'.join(texts)
if tool_calls:
message_param['tool_calls'] = tool_calls
- return message_param
+ yield message_param
else:
assert_never(message)
+ @classmethod
+ def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletionMessageParam]:
+ for part in message.parts:
+ if isinstance(part, SystemPromptPart):
+ yield chat.ChatCompletionSystemMessageParam(role='system', content=part.content)
+ elif isinstance(part, UserPromptPart):
+ yield chat.ChatCompletionUserMessageParam(role='user', content=part.content)
+ elif isinstance(part, ToolReturnPart):
+ yield chat.ChatCompletionToolMessageParam(
+ role='tool',
+ tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+ content=part.model_response_str(),
+ )
+ elif isinstance(part, RetryPromptPart):
+ if part.tool_name is None:
+ yield chat.ChatCompletionUserMessageParam(role='user', content=part.model_response())
+ else:
+ yield chat.ChatCompletionToolMessageParam(
+ role='tool',
+ tool_call_id=_guard_tool_call_id(t=part, model_source='OpenAI'),
+ content=part.model_response(),
+ )
+ else:
+ assert_never(part)
+
@dataclass
class OpenAIStreamTextResponse(StreamTextResponse):
diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py
index 77b63739a8..4167091447 100644
--- a/pydantic_ai_slim/pydantic_ai/models/test.py
+++ b/pydantic_ai_slim/pydantic_ai/models/test.py
@@ -13,12 +13,13 @@
from .. import _utils
from ..messages import (
- Message,
+ ModelMessage,
+ ModelRequest,
ModelResponse,
- RetryPrompt,
+ RetryPromptPart,
TextPart,
ToolCallPart,
- ToolReturn,
+ ToolReturnPart,
)
from ..result import Cost
from ..settings import ModelSettings
@@ -129,13 +130,13 @@ class TestAgentModel(AgentModel):
seed: int
async def request(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> tuple[ModelResponse, Cost]:
return self._request(messages, model_settings), Cost()
@asynccontextmanager
async def request_stream(
- self, messages: list[Message], model_settings: ModelSettings | None
+ self, messages: list[ModelMessage], model_settings: ModelSettings | None
) -> AsyncIterator[EitherStreamedResponse]:
msg = self._request(messages, model_settings)
cost = Cost()
@@ -159,34 +160,37 @@ async def request_stream(
def gen_tool_args(self, tool_def: ToolDefinition) -> Any:
return _JsonSchemaTestData(tool_def.parameters_json_schema, self.seed).generate()
- def _request(self, messages: list[Message], model_settings: ModelSettings | None) -> ModelResponse:
+ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings | None) -> ModelResponse:
# if there are tools, the first thing we want to do is call all of them
if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
return ModelResponse(
parts=[ToolCallPart.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
)
- # get messages since the last model response
- new_messages = _get_new_messages(messages)
-
- # check if there are any retry prompts, if so retry them
- new_retry_names = {m.tool_name for m in new_messages if isinstance(m, RetryPrompt)}
- if new_retry_names:
- return ModelResponse(
- parts=[
- ToolCallPart.from_dict(name, self.gen_tool_args(args))
- for name, args in self.tool_calls
- if name in new_retry_names
- ]
- )
+ if messages:
+ last_message = messages[-1]
+ assert isinstance(last_message, ModelRequest), 'Expected last message to be a `ModelRequest`.'
+
+ # check if there are any retry prompts, if so retry them
+ new_retry_names = {p.tool_name for p in last_message.parts if isinstance(p, RetryPromptPart)}
+ if new_retry_names:
+ return ModelResponse(
+ parts=[
+ ToolCallPart.from_dict(name, self.gen_tool_args(args))
+ for name, args in self.tool_calls
+ if name in new_retry_names
+ ]
+ )
if response_text := self.result.left:
if response_text.value is None:
# build up details of tool responses
output: dict[str, Any] = {}
for message in messages:
- if isinstance(message, ToolReturn):
- output[message.tool_name] = message.content
+ if isinstance(message, ModelRequest):
+ for part in message.parts:
+ if isinstance(part, ToolReturnPart):
+ output[part.tool_name] = part.content
if output:
return ModelResponse.from_text(pydantic_core.to_json(output).decode())
else:
@@ -204,18 +208,6 @@ def _request(self, messages: list[Message], model_settings: ModelSettings | None
return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, response_args)])
-def _get_new_messages(messages: list[Message]) -> list[Message]:
- last_model_index = None
- for i, m in enumerate(messages):
- if isinstance(m, ModelResponse):
- last_model_index = i
-
- if last_model_index is not None:
- return messages[last_model_index + 1 :]
- else:
- return []
-
-
@dataclass
class TestStreamTextResponse(StreamTextResponse):
"""A text response that streams test data."""
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index 4f382b8b85..2cb6be4276 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -71,19 +71,19 @@ class _BaseRunResult(ABC, Generic[ResultData]):
You should not import or use this type directly, instead use its subclasses `RunResult` and `StreamedRunResult`.
"""
- _all_messages: list[_messages.Message]
+ _all_messages: list[_messages.ModelMessage]
_new_message_index: int
- def all_messages(self) -> list[_messages.Message]:
+ def all_messages(self) -> list[_messages.ModelMessage]:
"""Return the history of _messages."""
# this is a method to be consistent with the other methods
return self._all_messages
def all_messages_json(self) -> bytes:
"""Return all messages from [`all_messages`][pydantic_ai.result._BaseRunResult.all_messages] as JSON bytes."""
- return _messages.MessagesTypeAdapter.dump_json(self.all_messages())
+ return _messages.ModelMessagesTypeAdapter.dump_json(self.all_messages())
- def new_messages(self) -> list[_messages.Message]:
+ def new_messages(self) -> list[_messages.ModelMessage]:
"""Return new messages associated with this run.
System prompts and any messages from older runs are excluded.
@@ -92,7 +92,7 @@ def new_messages(self) -> list[_messages.Message]:
def new_messages_json(self) -> bytes:
"""Return new messages from [`new_messages`][pydantic_ai.result._BaseRunResult.new_messages] as JSON bytes."""
- return _messages.MessagesTypeAdapter.dump_json(self.new_messages())
+ return _messages.ModelMessagesTypeAdapter.dump_json(self.new_messages())
@abstractmethod
def cost(self) -> Cost:
@@ -122,7 +122,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
_result_schema: _result.ResultSchema[ResultData] | None
_deps: AgentDeps
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
- _on_complete: Callable[[list[_messages.Message]], None]
+ _on_complete: Callable[[list[_messages.ModelMessage]], None]
is_complete: bool = field(default=False, init=False)
"""Whether the stream has all been received.
diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py
index 9f9b4fe8a4..5e22c2c312 100644
--- a/pydantic_ai_slim/pydantic_ai/tools.py
+++ b/pydantic_ai_slim/pydantic_ai/tools.py
@@ -45,7 +45,7 @@ class RunContext(Generic[AgentDeps]):
"""Dependencies for the agent."""
retry: int
"""Number of retries so far."""
- messages: list[_messages.Message]
+ messages: list[_messages.ModelMessage]
"""Messages exchanged in the conversation so far."""
tool_name: str | None
"""Name of the tool being called."""
@@ -89,7 +89,7 @@ class RunContext(Generic[AgentDeps]):
Usage `ToolPlainFunc[ToolParams]`.
"""
ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
-"""Either kind of tool function.
+"""Either part_kind of tool function.
This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
@@ -238,8 +238,8 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDeps]) -> ToolDefinition |
return tool_def
async def run(
- self, deps: AgentDeps, message: _messages.ToolCallPart, conv_messages: list[_messages.Message]
- ) -> _messages.Message:
+ self, deps: AgentDeps, message: _messages.ToolCallPart, conv_messages: list[_messages.ModelMessage]
+ ) -> _messages.ModelRequestPart:
"""Run the tool function asynchronously."""
try:
if isinstance(message.args, _messages.ArgsJson):
@@ -261,7 +261,7 @@ async def run(
return self._on_error(e, message)
self.current_retry = 0
- return _messages.ToolReturn(
+ return _messages.ToolReturnPart(
tool_name=message.tool_name,
content=response_content,
tool_call_id=message.tool_call_id,
@@ -272,7 +272,7 @@ def _call_args(
deps: AgentDeps,
args_dict: dict[str, Any],
message: _messages.ToolCallPart,
- conv_messages: list[_messages.Message],
+ conv_messages: list[_messages.ModelMessage],
) -> tuple[list[Any], dict[str, Any]]:
if self._single_arg_name:
args_dict = {self._single_arg_name: args_dict}
@@ -287,7 +287,7 @@ def _call_args(
def _on_error(
self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart
- ) -> _messages.RetryPrompt:
+ ) -> _messages.RetryPromptPart:
self.current_retry += 1
if self.max_retries is None or self.current_retry > self.max_retries:
raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
@@ -296,7 +296,7 @@ def _on_error(
content = exc.errors(include_url=False)
else:
content = exc.message
- return _messages.RetryPrompt(
+ return _messages.RetryPromptPart(
tool_name=call_message.tool_name,
content=content,
tool_call_id=call_message.tool_call_id,
diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py
index 3695c0983b..a21427f467 100644
--- a/tests/models/test_anthropic.py
+++ b/tests/models/test_anthropic.py
@@ -12,12 +12,13 @@
from pydantic_ai import Agent, ModelRetry
from pydantic_ai.messages import (
ArgsDict,
+ ModelRequest,
ModelResponse,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from pydantic_ai.result import Cost
@@ -100,9 +101,9 @@ async def test_sync_request_text_response(allow_model_requests: None):
assert result.cost() == snapshot(Cost(request_tokens=5, response_tokens=10, total_tokens=15))
assert result.all_messages() == snapshot(
[
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='world', timestamp=IsNow(tz=timezone.utc)),
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='world', timestamp=IsNow(tz=timezone.utc)),
]
)
@@ -135,7 +136,7 @@ async def test_request_structured_response(allow_model_requests: None):
assert result.data == [1, 2, 3]
assert result.all_messages() == snapshot(
[
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[
ToolCallPart(
@@ -146,11 +147,15 @@ async def test_request_structured_response(allow_model_requests: None):
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- tool_call_id='123',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ tool_call_id='123',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
]
)
@@ -187,8 +192,12 @@ async def get_location(loc_name: str) -> str:
assert result.data == 'final response'
assert result.all_messages() == snapshot(
[
- SystemPrompt(content='this is the system prompt'),
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='this is the system prompt'),
+ UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[
ToolCallPart(
@@ -199,11 +208,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- RetryPrompt(
- tool_name='get_location',
- content='Wrong location, please try again',
- tool_call_id='1',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ content='Wrong location, please try again',
+ tool_name='get_location',
+ tool_call_id='1',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse(
parts=[
@@ -215,11 +228,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='get_location',
- content='{"lat": 51, "lng": 0}',
- tool_call_id='2',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='get_location',
+ content='{"lat": 51, "lng": 0}',
+ tool_call_id='2',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
]
diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py
index abf05472e9..6820620af1 100644
--- a/tests/models/test_gemini.py
+++ b/tests/models/test_gemini.py
@@ -15,12 +15,13 @@
from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, UserError
from pydantic_ai.messages import (
ArgsDict,
+ ModelRequest,
ModelResponse,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from pydantic_ai.models.gemini import (
ApiKeyAuth,
@@ -391,7 +392,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
assert result.data == 'Hello world'
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='Hello world', timestamp=IsNow(tz=timezone.utc)),
]
)
@@ -401,9 +402,9 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
assert result.data == 'Hello world'
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='Hello world', timestamp=IsNow(tz=timezone.utc)),
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='Hello world', timestamp=IsNow(tz=timezone.utc)),
]
)
@@ -423,7 +424,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
assert result.data == [1, 2, 123]
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[
ToolCallPart(
@@ -433,10 +434,12 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
),
]
)
@@ -471,8 +474,12 @@ async def get_location(loc_name: str) -> str:
assert result.data == 'final response'
assert result.all_messages() == snapshot(
[
- SystemPrompt(content='this is the system prompt'),
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='this is the system prompt'),
+ UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[
ToolCallPart(
@@ -482,8 +489,14 @@ async def get_location(loc_name: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- RetryPrompt(
- tool_name='get_location', content='Wrong location, please try again', timestamp=IsNow(tz=timezone.utc)
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ content='Wrong location, please try again',
+ tool_name='get_location',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse(
parts=[
@@ -494,7 +507,13 @@ async def get_location(loc_name: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
+ ),
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
]
)
@@ -636,7 +655,7 @@ async def bar(y: str) -> str:
assert result.cost() == snapshot(Cost(request_tokens=3, response_tokens=6, total_tokens=9))
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[
ToolCallPart(tool_name='foo', args=ArgsDict(args_dict={'x': 'a'})),
@@ -644,12 +663,18 @@ async def bar(y: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='foo', content='a', timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(tool_name='foo', content='a', timestamp=IsNow(tz=timezone.utc)),
+ ToolReturnPart(tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
),
ModelResponse(
parts=[
diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py
index 9f54049249..52004134a6 100644
--- a/tests/models/test_groq.py
+++ b/tests/models/test_groq.py
@@ -14,12 +14,13 @@
from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, _utils
from pydantic_ai.messages import (
ArgsJson,
+ ModelRequest,
ModelResponse,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from pydantic_ai.result import Cost
@@ -137,9 +138,9 @@ async def test_request_simple_success(allow_model_requests: None):
assert result.cost() == snapshot(Cost())
assert result.all_messages() == snapshot(
[
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
]
)
@@ -180,7 +181,7 @@ async def test_request_structured_response(allow_model_requests: None):
assert result.data == [1, 2, 123]
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[
ToolCallPart(
@@ -191,11 +192,15 @@ async def test_request_structured_response(allow_model_requests: None):
],
timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- tool_call_id='123',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ tool_call_id='123',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
]
)
@@ -256,8 +261,12 @@ async def get_location(loc_name: str) -> str:
assert result.data == 'final response'
assert result.all_messages() == snapshot(
[
- SystemPrompt(content='this is the system prompt'),
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='this is the system prompt'),
+ UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[
ToolCallPart(
@@ -268,11 +277,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
- RetryPrompt(
- tool_name='get_location',
- content='Wrong location, please try again',
- tool_call_id='1',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ tool_name='get_location',
+ content='Wrong location, please try again',
+ tool_call_id='1',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse(
parts=[
@@ -284,11 +297,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
),
- ToolReturn(
- tool_name='get_location',
- content='{"lat": 51, "lng": 0}',
- tool_call_id='2',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='get_location',
+ content='{"lat": 51, "lng": 0}',
+ tool_call_id='2',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse.from_text(content='final response', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc)),
]
@@ -397,12 +414,13 @@ async def test_stream_structured(allow_model_requests: None):
assert result.cost() == snapshot(Cost())
assert result.all_messages() == snapshot(
[
- UserPrompt(content='', timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- tool_call_id=None,
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
),
ModelResponse(
parts=[
diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py
index 765c312928..db6eab886b 100644
--- a/tests/models/test_mistral.py
+++ b/tests/models/test_mistral.py
@@ -18,12 +18,13 @@
from pydantic_ai.messages import (
ArgsDict,
ArgsJson,
+ ModelRequest,
ModelResponse,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from ..conftest import IsNow, try_import
@@ -236,9 +237,9 @@ async def test_multiple_completions(allow_model_requests: None):
assert result.cost().total_tokens == 1
assert result.all_messages() == snapshot(
[
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='world', timestamp=IsNow(tz=timezone.utc)),
- UserPrompt(content='hello again', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='hello again', timestamp=IsNow(tz=timezone.utc)),
]
)
@@ -280,11 +281,11 @@ async def test_three_completions(allow_model_requests: None):
assert result.cost().total_tokens == 1
assert result.all_messages() == snapshot(
[
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='world', timestamp=IsNow(tz=timezone.utc)),
- UserPrompt(content='hello again', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='hello again', timestamp=IsNow(tz=timezone.utc)),
- UserPrompt(content='final message', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='final message', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='final message', timestamp=IsNow(tz=timezone.utc)),
]
)
@@ -387,7 +388,7 @@ class CityLocation(BaseModel):
assert result.data == CityLocation(city='paris', country='france')
assert result.all_messages() == snapshot(
[
- UserPrompt(content='User prompt value', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[
ToolCallPart(
@@ -398,11 +399,15 @@ class CityLocation(BaseModel):
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- tool_call_id='123',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ tool_call_id='123',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
]
)
@@ -443,7 +448,7 @@ class CityLocation(BaseModel):
assert result.data == CityLocation(city='paris', country='france')
assert result.all_messages() == snapshot(
[
- UserPrompt(content='User prompt value', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[
ToolCallPart(
@@ -454,11 +459,15 @@ class CityLocation(BaseModel):
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- tool_call_id='123',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ tool_call_id='123',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
]
)
@@ -490,8 +499,12 @@ async def test_request_result_type_with_arguments_str_response(allow_model_reque
assert result.data == 42
assert result.all_messages() == snapshot(
[
- SystemPrompt(content='System prompt value'),
- UserPrompt(content='User prompt value', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='System prompt value'),
+ UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[
ToolCallPart(
@@ -502,11 +515,15 @@ async def test_request_result_type_with_arguments_str_response(allow_model_reque
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- tool_call_id='123',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ tool_call_id='123',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
]
)
@@ -995,8 +1012,12 @@ async def get_location(loc_name: str) -> str:
assert result.data == 'final response'
assert result.all_messages() == snapshot(
[
- SystemPrompt(content='this is the system prompt'),
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='this is the system prompt'),
+ UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[
ToolCallPart(
@@ -1007,11 +1028,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- RetryPrompt(
- tool_name='get_location',
- content='Wrong location, please try again',
- tool_call_id='1',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ content='Wrong location, please try again',
+ tool_name='get_location',
+ tool_call_id='1',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse(
parts=[
@@ -1023,11 +1048,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='get_location',
- content='{"lat": 51, "lng": 0}',
- tool_call_id='2',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='get_location',
+ content='{"lat": 51, "lng": 0}',
+ tool_call_id='2',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
]
@@ -1117,8 +1146,12 @@ async def get_location(loc_name: str) -> str:
assert result.data == {'lat': 51, 'lng': 0}
assert result.all_messages() == snapshot(
[
- SystemPrompt(content='this is the system prompt'),
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='this is the system prompt'),
+ UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[
ToolCallPart(
@@ -1129,11 +1162,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- RetryPrompt(
- tool_name='get_location',
- content='Wrong location, please try again',
- tool_call_id='1',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ content='Wrong location, please try again',
+ tool_name='get_location',
+ tool_call_id='1',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse(
parts=[
@@ -1145,11 +1182,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='get_location',
- content='{"lat": 51, "lng": 0}',
- tool_call_id='2',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='get_location',
+ content='{"lat": 51, "lng": 0}',
+ tool_call_id='2',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse(
parts=[
@@ -1161,11 +1202,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- tool_call_id='1',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ tool_call_id='1',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
]
)
@@ -1241,12 +1286,11 @@ async def get_location(loc_name: str) -> str:
assert result.all_messages() == snapshot(
[
- SystemPrompt(content='this is the system prompt', role='user', message_kind='system-prompt'),
- UserPrompt(
- content='User prompt value',
- timestamp=IsNow(tz=timezone.utc),
- role='user',
- message_kind='user-prompt',
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='this is the system prompt'),
+ UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc)),
+ ]
),
ModelResponse(
parts=[
@@ -1258,29 +1302,31 @@ async def get_location(loc_name: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='get_location',
- content='{"lat": 51, "lng": 0}',
- tool_call_id='1',
- timestamp=IsNow(tz=timezone.utc),
- role='user',
- message_kind='tool-return',
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='get_location',
+ content='{"lat": 51, "lng": 0}',
+ tool_call_id='1',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- tool_call_id='1',
- timestamp=IsNow(tz=timezone.utc),
- role='user',
- message_kind='tool-return',
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ tool_call_id='1',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse(
parts=[
ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"won": true}'), tool_call_id='1')
],
timestamp=IsNow(tz=timezone.utc),
- role='model',
- message_kind='model-response',
),
]
)
@@ -1341,8 +1387,12 @@ async def get_location(loc_name: str) -> str:
assert result.all_messages() == snapshot(
[
- SystemPrompt(content='this is the system prompt'),
- UserPrompt(content='User prompt value', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='this is the system prompt'),
+ UserPromptPart(content='User prompt value', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[
ToolCallPart(
@@ -1353,11 +1403,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='get_location',
- content='{"lat": 51, "lng": 0}',
- tool_call_id='1',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='get_location',
+ content='{"lat": 51, "lng": 0}',
+ tool_call_id='1',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
]
diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py
index d9e6dc9657..ea11ee7996 100644
--- a/tests/models/test_model_function.py
+++ b/tests/models/test_model_function.py
@@ -12,12 +12,14 @@
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.messages import (
- Message,
+ ModelMessage,
+ ModelRequest,
ModelResponse,
- SystemPrompt,
+ SystemPromptPart,
+ TextPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
from pydantic_ai.models.test import TestModel
@@ -28,8 +30,8 @@
pytestmark = pytest.mark.anyio
-async def return_last(messages: list[Message], _: AgentInfo) -> ModelResponse:
- last = messages[-1]
+async def return_last(messages: list[ModelMessage], _: AgentInfo) -> ModelResponse:
+ last = messages[-1].parts[-1]
response = asdict(last)
response.pop('timestamp', None)
response['message_count'] = len(messages)
@@ -39,50 +41,40 @@ async def return_last(messages: list[Message], _: AgentInfo) -> ModelResponse:
def test_simple(set_event_loop: None):
agent = Agent(FunctionModel(return_last))
result = agent.run_sync('Hello')
- assert result.data == snapshot("content='Hello' role='user' message_kind='user-prompt' message_count=1")
+ assert result.data == snapshot("content='Hello' part_kind='user-prompt' message_count=1")
assert result.all_messages() == snapshot(
[
- UserPrompt(
- content='Hello',
- timestamp=IsNow(tz=timezone.utc),
- role='user',
- message_kind='user-prompt',
- ),
- ModelResponse.from_text(
- content="content='Hello' role='user' message_kind='user-prompt' message_count=1",
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
+ ModelResponse(
+ parts=[TextPart(content="content='Hello' part_kind='user-prompt' message_count=1")],
timestamp=IsNow(tz=timezone.utc),
),
]
)
result2 = agent.run_sync('World', message_history=result.all_messages())
- assert result2.data == snapshot("content='World' role='user' message_kind='user-prompt' message_count=3")
+ assert result2.data == snapshot("content='World' part_kind='user-prompt' message_count=3")
assert result2.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc), role='user', message_kind='user-prompt'),
- ModelResponse.from_text(
- content="content='Hello' role='user' message_kind='user-prompt' message_count=1",
- timestamp=IsNow(tz=timezone.utc),
- ),
- UserPrompt(
- content='World',
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
+ ModelResponse(
+ parts=[TextPart(content="content='Hello' part_kind='user-prompt' message_count=1")],
timestamp=IsNow(tz=timezone.utc),
- role='user',
- message_kind='user-prompt',
),
- ModelResponse.from_text(
- content="content='World' role='user' message_kind='user-prompt' message_count=3",
+ ModelRequest(parts=[UserPromptPart(content='World', timestamp=IsNow(tz=timezone.utc))]),
+ ModelResponse(
+ parts=[TextPart(content="content='World' part_kind='user-prompt' message_count=3")],
timestamp=IsNow(tz=timezone.utc),
),
]
)
-async def weather_model(messages: list[Message], info: AgentInfo) -> ModelResponse: # pragma: no cover
+async def weather_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: # pragma: no cover
assert info.allow_text_result
assert {t.name for t in info.function_tools} == {'get_location', 'get_weather'}
- last = messages[-1]
- if isinstance(last, UserPrompt):
+ last = messages[-1].parts[-1]
+ if isinstance(last, UserPromptPart):
return ModelResponse(
parts=[
ToolCallPart.from_json(
@@ -91,11 +83,17 @@ async def weather_model(messages: list[Message], info: AgentInfo) -> ModelRespon
)
]
)
- elif isinstance(last, ToolReturn):
+ elif isinstance(last, ToolReturnPart):
if last.tool_name == 'get_location':
return ModelResponse(parts=[ToolCallPart.from_json('get_weather', last.model_response_str())])
elif last.tool_name == 'get_weather':
- location_name = next(m.content for m in messages if isinstance(m, UserPrompt))
+ location_name: str | None = None
+ for m in messages:
+ location_name = next((part.content for part in m.parts if isinstance(part, UserPromptPart)), None)
+ if location_name is not None:
+ break
+
+ assert location_name is not None
return ModelResponse.from_text(f'{last.content} in {location_name}')
raise ValueError(f'Unexpected message: {last}')
@@ -127,24 +125,17 @@ def test_weather(set_event_loop: None):
assert result.data == 'Raining in London'
assert result.all_messages() == snapshot(
[
- UserPrompt(
- content='London',
- timestamp=IsNow(tz=timezone.utc),
- role='user',
- message_kind='user-prompt',
- ),
+ ModelRequest(parts=[UserPromptPart(content='London', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart.from_json('get_location', '{"location_description": "London"}')],
timestamp=IsNow(tz=timezone.utc),
- role='model',
- message_kind='model-response',
),
- ToolReturn(
- tool_name='get_location',
- content='{"lat": 51, "lng": 0}',
- timestamp=IsNow(tz=timezone.utc),
- role='user',
- message_kind='tool-return',
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='get_location', content='{"lat": 51, "lng": 0}', timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
),
ModelResponse(
parts=[
@@ -154,15 +145,9 @@ def test_weather(set_event_loop: None):
)
],
timestamp=IsNow(tz=timezone.utc),
- role='model',
- message_kind='model-response',
),
- ToolReturn(
- tool_name='get_weather',
- content='Raining',
- timestamp=IsNow(tz=timezone.utc),
- role='user',
- message_kind='tool-return',
+ ModelRequest(
+ parts=[ToolReturnPart(tool_name='get_weather', content='Raining', timestamp=IsNow(tz=timezone.utc))]
),
ModelResponse.from_text(
content='Raining in London',
@@ -175,9 +160,9 @@ def test_weather(set_event_loop: None):
assert result.data == 'Sunny in Ipswich'
-async def call_function_model(messages: list[Message], _: AgentInfo) -> ModelResponse: # pragma: no cover
- last = messages[-1]
- if isinstance(last, UserPrompt):
+async def call_function_model(messages: list[ModelMessage], _: AgentInfo) -> ModelResponse: # pragma: no cover
+ last = messages[-1].parts[-1]
+ if isinstance(last, UserPromptPart):
if last.content.startswith('{'):
details = json.loads(last.content)
return ModelResponse(
@@ -188,7 +173,7 @@ async def call_function_model(messages: list[Message], _: AgentInfo) -> ModelRes
)
]
)
- elif isinstance(last, ToolReturn):
+ elif isinstance(last, ToolReturnPart):
return ModelResponse.from_text(pydantic_core.to_json(last).decode())
raise ValueError(f'Unexpected message: {last}')
@@ -214,13 +199,12 @@ def test_var_args(set_event_loop: None):
'content': '{"args": [1, 2, 3]}',
'tool_call_id': None,
'timestamp': IsStr() & IsNow(iso_string=True, tz=timezone.utc),
- 'role': 'user',
- 'message_kind': 'tool-return',
+ 'part_kind': 'tool-return',
}
)
-async def call_tool(messages: list[Message], info: AgentInfo) -> ModelResponse:
+async def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
if len(messages) == 1:
assert len(info.function_tools) == 1
tool_name = info.function_tools[0].name
@@ -267,7 +251,7 @@ def get_check_foobar(ctx: RunContext[tuple[str, str]]) -> str:
def test_model_arg(set_event_loop: None):
agent = Agent()
result = agent.run_sync('Hello', model=FunctionModel(return_last))
- assert result.data == snapshot("content='Hello' role='user' message_kind='user-prompt' message_count=1")
+ assert result.data == snapshot("content='Hello' part_kind='user-prompt' message_count=1")
with pytest.raises(RuntimeError, match='`model` must be set either when creating the agent or when calling it.'):
agent.run_sync('Hello')
@@ -307,13 +291,13 @@ def spam() -> str:
def test_register_all(set_event_loop: None):
- async def f(messages: list[Message], info: AgentInfo) -> ModelResponse:
+ async def f(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
return ModelResponse.from_text(
f'messages={len(messages)} allow_text_result={info.allow_text_result} tools={len(info.function_tools)}'
)
result = agent_all.run_sync('Hello', model=FunctionModel(f))
- assert result.data == snapshot('messages=2 allow_text_result=True tools=5')
+ assert result.data == snapshot('messages=1 allow_text_result=True tools=5')
def test_call_all(set_event_loop: None):
@@ -321,8 +305,12 @@ def test_call_all(set_event_loop: None):
assert result.data == snapshot('{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}')
assert result.all_messages() == snapshot(
[
- SystemPrompt(content='foobar'),
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='foobar'),
+ UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[
ToolCallPart.from_dict('foo', {'x': 0}),
@@ -333,11 +321,15 @@ def test_call_all(set_event_loop: None):
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='foo', content='1', timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(tool_name='bar', content='2', timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(tool_name='baz', content='3', timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(tool_name='qux', content='4', timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(tool_name='quz', content='a', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(tool_name='foo', content='1', timestamp=IsNow(tz=timezone.utc)),
+ ToolReturnPart(tool_name='bar', content='2', timestamp=IsNow(tz=timezone.utc)),
+ ToolReturnPart(tool_name='baz', content='3', timestamp=IsNow(tz=timezone.utc)),
+ ToolReturnPart(tool_name='qux', content='4', timestamp=IsNow(tz=timezone.utc)),
+ ToolReturnPart(tool_name='quz', content='a', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse.from_text(
content='{"foo":"1","bar":"2","baz":"3","qux":"4","quz":"a"}', timestamp=IsNow(tz=timezone.utc)
),
@@ -348,7 +340,7 @@ def test_call_all(set_event_loop: None):
def test_retry_str(set_event_loop: None):
call_count = 0
- async def try_again(msgs_: list[Message], _agent_info: AgentInfo) -> ModelResponse:
+ async def try_again(msgs_: list[ModelMessage], _agent_info: AgentInfo) -> ModelResponse:
nonlocal call_count
call_count += 1
@@ -370,7 +362,7 @@ async def validate_result(r: str) -> str:
def test_retry_result_type(set_event_loop: None):
call_count = 0
- async def try_again(messages: list[Message], _: AgentInfo) -> ModelResponse:
+ async def try_again(messages: list[ModelMessage], _: AgentInfo) -> ModelResponse:
nonlocal call_count
call_count += 1
@@ -392,7 +384,7 @@ async def validate_result(r: Foo) -> Foo:
assert result.data == snapshot(Foo(x=2))
-async def stream_text_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[str]:
+async def stream_text_function(_messages: list[ModelMessage], _: AgentInfo) -> AsyncIterator[str]:
yield 'hello '
yield 'world'
@@ -403,7 +395,7 @@ async def test_stream_text():
assert await result.get_data() == snapshot('hello world')
assert result.all_messages() == snapshot(
[
- UserPrompt(content='', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='hello world', timestamp=IsNow(tz=timezone.utc)),
]
)
@@ -416,7 +408,7 @@ class Foo(BaseModel):
async def test_stream_structure():
async def stream_structured_function(
- _messages: list[Message], agent_info: AgentInfo
+ _messages: list[ModelMessage], agent_info: AgentInfo
) -> AsyncIterator[DeltaToolCalls]:
assert agent_info.result_tools is not None
assert len(agent_info.result_tools) == 1
@@ -440,7 +432,7 @@ async def test_pass_both():
Agent(FunctionModel(return_last, stream_function=stream_text_function))
-async def stream_text_function_empty(_messages: list[Message], _: AgentInfo) -> AsyncIterator[str]:
+async def stream_text_function_empty(_messages: list[ModelMessage], _: AgentInfo) -> AsyncIterator[str]:
if False:
yield 'hello '
diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py
index 2d57d92102..80f9a9656d 100644
--- a/tests/models/test_model_test.py
+++ b/tests/models/test_model_test.py
@@ -12,11 +12,12 @@
from pydantic_ai import Agent, ModelRetry
from pydantic_ai.messages import (
+ ModelRequest,
ModelResponse,
- RetryPrompt,
+ RetryPromptPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from pydantic_ai.models.test import TestModel, _chars, _JsonSchemaTestData # pyright: ignore[reportPrivateUsage]
@@ -91,14 +92,18 @@ async def my_ret(x: int) -> str:
assert result.data == snapshot('{"my_ret":"1"}')
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart.from_dict('my_ret', {'x': 0})],
timestamp=IsNow(tz=timezone.utc),
),
- RetryPrompt(tool_name='my_ret', content='First call failed', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(content='First call failed', tool_name='my_ret', timestamp=IsNow(tz=timezone.utc))
+ ]
+ ),
ModelResponse(parts=[ToolCallPart.from_dict('my_ret', {'x': 0})], timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(tool_name='my_ret', content='1', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[ToolReturnPart(tool_name='my_ret', content='1', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='{"my_ret":"1"}', timestamp=IsNow(tz=timezone.utc)),
]
)
diff --git a/tests/models/test_ollama.py b/tests/models/test_ollama.py
index e59c3c03fd..d0e5088ef4 100644
--- a/tests/models/test_ollama.py
+++ b/tests/models/test_ollama.py
@@ -7,8 +7,9 @@
from pydantic_ai import Agent
from pydantic_ai.messages import (
+ ModelRequest,
ModelResponse,
- UserPrompt,
+ UserPromptPart,
)
from pydantic_ai.result import Cost
@@ -53,9 +54,9 @@ async def test_request_simple_success(allow_model_requests: None):
assert result.cost() == snapshot(Cost())
assert result.all_messages() == snapshot(
[
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
]
)
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index c48b0279c1..50d0c5779b 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -14,12 +14,13 @@
from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior, _utils
from pydantic_ai.messages import (
ArgsJson,
+ ModelRequest,
ModelResponse,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from pydantic_ai.result import Cost
@@ -146,9 +147,9 @@ async def test_request_simple_success(allow_model_requests: None):
assert result.cost() == snapshot(Cost())
assert result.all_messages() == snapshot(
[
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='world', timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc)),
]
)
@@ -190,7 +191,7 @@ async def test_request_structured_response(allow_model_requests: None):
assert result.data == [1, 2, 123]
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[
ToolCallPart(
@@ -201,11 +202,15 @@ async def test_request_structured_response(allow_model_requests: None):
],
timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- tool_call_id='123',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ tool_call_id='123',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
]
)
@@ -268,8 +273,12 @@ async def get_location(loc_name: str) -> str:
assert result.data == 'final response'
assert result.all_messages() == snapshot(
[
- SystemPrompt(content='this is the system prompt'),
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='this is the system prompt'),
+ UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[
ToolCallPart(
@@ -280,11 +289,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
),
- RetryPrompt(
- tool_name='get_location',
- content='Wrong location, please try again',
- tool_call_id='1',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ content='Wrong location, please try again',
+ tool_name='get_location',
+ tool_call_id='1',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse(
parts=[
@@ -296,11 +309,15 @@ async def get_location(loc_name: str) -> str:
],
timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
),
- ToolReturn(
- tool_name='get_location',
- content='{"lat": 51, "lng": 0}',
- tool_call_id='2',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='get_location',
+ content='{"lat": 51, "lng": 0}',
+ tool_call_id='2',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse.from_text(content='final response', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc)),
]
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 3f8ca32180..a079e245e9 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -13,13 +13,15 @@
from pydantic_ai.messages import (
ArgsDict,
ArgsJson,
- Message,
+ ModelMessage,
+ ModelRequest,
ModelResponse,
- RetryPrompt,
- SystemPrompt,
+ RetryPromptPart,
+ SystemPromptPart,
+ TextPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from pydantic_ai.models import cached_async_http_client
from pydantic_ai.models.function import AgentInfo, FunctionModel
@@ -33,7 +35,7 @@
def test_result_tuple(set_event_loop: None):
- def return_tuple(_: list[Message], info: AgentInfo) -> ModelResponse:
+ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.result_tools is not None
args_json = '{"response": ["foo", "bar"]}'
return ModelResponse(parts=[ToolCallPart.from_json(info.result_tools[0].name, args_json)])
@@ -50,7 +52,7 @@ class Foo(BaseModel):
def test_result_pydantic_model(set_event_loop: None):
- def return_model(_: list[Message], info: AgentInfo) -> ModelResponse:
+ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.result_tools is not None
args_json = '{"a": 1, "b": "foo"}'
return ModelResponse(parts=[ToolCallPart.from_json(info.result_tools[0].name, args_json)])
@@ -63,7 +65,7 @@ def return_model(_: list[Message], info: AgentInfo) -> ModelResponse:
def test_result_pydantic_model_retry(set_event_loop: None):
- def return_model(messages: list[Message], info: AgentInfo) -> ModelResponse:
+ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.result_tools is not None
if len(messages) == 1:
args_json = '{"a": "wrong", "b": "foo"}'
@@ -81,35 +83,45 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelResponse:
assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart.from_json('final_result', '{"a": "wrong", "b": "foo"}')],
timestamp=IsNow(tz=timezone.utc),
),
- RetryPrompt(
- tool_name='final_result',
- content=[
- {
- 'type': 'int_parsing',
- 'loc': ('a',),
- 'msg': 'Input should be a valid integer, unable to parse string as an integer',
- 'input': 'wrong',
- }
- ],
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ tool_name='final_result',
+ content=[
+ {
+ 'type': 'int_parsing',
+ 'loc': ('a',),
+ 'msg': 'Input should be a valid integer, unable to parse string as an integer',
+ 'input': 'wrong',
+ }
+ ],
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse(
parts=[ToolCallPart.from_json('final_result', '{"a": 42, "b": "foo"}')],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
+ ),
]
)
- assert result.all_messages_json().startswith(b'[{"content":"Hello"')
+ assert result.all_messages_json().startswith(b'[{"parts":[{"content":"Hello",')
def test_result_pydantic_model_validation_error(set_event_loop: None):
- def return_model(messages: list[Message], info: AgentInfo) -> ModelResponse:
+ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.result_tools is not None
if len(messages) == 1:
args_json = '{"a": 1, "b": "foo"}'
@@ -132,13 +144,21 @@ def check_b(cls, v: str) -> str:
result = agent.run_sync('Hello')
assert isinstance(result.data, Bar)
assert result.data.model_dump() == snapshot({'a': 1, 'b': 'bar'})
- messages_kinds = [m.message_kind for m in result.all_messages()]
- assert messages_kinds == snapshot(
- ['user-prompt', 'model-response', 'retry-prompt', 'model-response', 'tool-return']
+ messages_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result.all_messages()]
+ assert messages_part_kinds == snapshot(
+ [
+ ('request', ['user-prompt']),
+ ('response', ['tool-call']),
+ ('request', ['retry-prompt']),
+ ('response', ['tool-call']),
+ ('request', ['tool-return']),
+ ]
)
- retry_prompt = result.all_messages()[2]
- assert isinstance(retry_prompt, RetryPrompt)
+ user_retry = result.all_messages()[2]
+ assert isinstance(user_retry, ModelRequest)
+ retry_prompt = user_retry.parts[0]
+ assert isinstance(retry_prompt, RetryPromptPart)
assert retry_prompt.model_response() == snapshot("""\
1 validation errors: [
{
@@ -155,7 +175,7 @@ def check_b(cls, v: str) -> str:
def test_result_validator(set_event_loop: None):
- def return_model(messages: list[Message], info: AgentInfo) -> ModelResponse:
+ def return_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.result_tools is not None
if len(messages) == 1:
args_json = '{"a": 41, "b": "foo"}'
@@ -178,17 +198,29 @@ def validate_result(ctx: RunContext[None], r: Foo) -> Foo:
assert result.data.model_dump() == {'a': 42, 'b': 'foo'}
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart.from_json('final_result', '{"a": 41, "b": "foo"}')],
timestamp=IsNow(tz=timezone.utc),
),
- RetryPrompt(tool_name='final_result', content='"a" should be 42', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ content='"a" should be 42', tool_name='final_result', timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
+ ),
ModelResponse(
parts=[ToolCallPart.from_json('final_result', '{"a": 42, "b": "foo"}')],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
+ ),
]
)
@@ -196,7 +228,7 @@ def validate_result(ctx: RunContext[None], r: Foo) -> Foo:
def test_plain_response(set_event_loop: None):
call_index = 0
- def return_tuple(_: list[Message], info: AgentInfo) -> ModelResponse:
+ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
nonlocal call_index
assert info.result_tools is not None
@@ -214,17 +246,27 @@ def return_tuple(_: list[Message], info: AgentInfo) -> ModelResponse:
assert call_index == 2
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse.from_text(content='hello', timestamp=IsNow(tz=timezone.utc)),
- RetryPrompt(
- content='Plain text responses are not permitted, please call one of the functions instead.',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ content='Plain text responses are not permitted, please call one of the functions instead.',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse(
parts=[ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"response": ["foo", "bar"]}'))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
+ ),
]
)
@@ -422,13 +464,20 @@ async def ret_a(x: str) -> str:
result1 = agent.run_sync('Hello')
assert result1.new_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='Foobar'),
+ UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[ToolCallPart(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)),
- ModelResponse.from_text(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
+ ),
+ ModelResponse(parts=[TextPart(content='{"ret_a":"a-apple"}')], timestamp=IsNow(tz=timezone.utc)),
]
)
@@ -437,41 +486,40 @@ async def ret_a(x: str) -> str:
assert result2 == snapshot(
RunResult(
_all_messages=[
- SystemPrompt(content='Foobar'),
- UserPrompt(
- content='Hello',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='Foobar'),
+ UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ]
),
ModelResponse(
parts=[ToolCallPart(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='ret_a',
- content='a-apple',
- timestamp=IsNow(tz=timezone.utc),
- ),
- ModelResponse.from_text(
- content='{"ret_a":"a-apple"}',
- timestamp=IsNow(tz=timezone.utc),
- ),
- UserPrompt(
- content='Hello again',
- timestamp=IsNow(tz=timezone.utc),
- ),
- ModelResponse.from_text(
- content='{"ret_a":"a-apple"}',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
),
+ ModelResponse(parts=[TextPart(content='{"ret_a":"a-apple"}')], timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
+ ModelResponse(parts=[TextPart(content='{"ret_a":"a-apple"}')], timestamp=IsNow(tz=timezone.utc)),
],
- _new_message_index=5,
+ _new_message_index=4,
data='{"ret_a":"a-apple"}',
_cost=Cost(),
)
)
- new_msg_kinds = [msg.message_kind for msg in result2.new_messages()]
- assert new_msg_kinds == snapshot(['user-prompt', 'model-response'])
- assert result2.new_messages_json().startswith(b'[{"content":"Hello again",')
+ new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
+ assert new_msg_part_kinds == snapshot(
+ [
+ ('request', ['system-prompt', 'user-prompt']),
+ ('response', ['tool-call']),
+ ('request', ['tool-return']),
+ ('response', ['text']),
+ ('request', ['user-prompt']),
+ ('response', ['text']),
+ ]
+ )
+ assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",')
# if we pass all_messages, system prompt is NOT inserted before the message_history messages,
# so only one system prompt
@@ -479,21 +527,26 @@ async def ret_a(x: str) -> str:
# same as result2 except for datetimes
assert result3 == snapshot(
RunResult(
- data='{"ret_a":"a-apple"}',
_all_messages=[
- SystemPrompt(content='Foobar'),
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='Foobar'),
+ UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[ToolCallPart(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)),
- ModelResponse.from_text(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)),
- # second call, notice no repeated system prompt
- UserPrompt(content='Hello again', timestamp=IsNow(tz=timezone.utc)),
- ModelResponse.from_text(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
+ ),
+ ModelResponse(parts=[TextPart(content='{"ret_a":"a-apple"}')], timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc))]),
+ ModelResponse(parts=[TextPart(content='{"ret_a":"a-apple"}')], timestamp=IsNow(tz=timezone.utc)),
],
- _new_message_index=5,
+ _new_message_index=4,
+ data='{"ret_a":"a-apple"}',
_cost=Cost(),
)
)
@@ -514,103 +567,106 @@ async def ret_a(x: str) -> str:
result1 = agent.run_sync('Hello')
assert result1.new_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='Foobar'),
+ UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ]
+ ),
ModelResponse(
parts=[ToolCallPart(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
+ ),
ModelResponse(
parts=[ToolCallPart(tool_name='final_result', args=ArgsDict(args_dict={'a': 0}), tool_call_id=None)],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- tool_call_id=None,
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
),
]
)
- # if we pass new_messages, system prompt is inserted before the message_history messages
result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
assert result2 == snapshot(
RunResult(
data=Response(a=0),
_all_messages=[
- SystemPrompt(content='Foobar'),
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
- ModelResponse(
- parts=[ToolCallPart(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))],
- timestamp=IsNow(tz=timezone.utc),
- ),
- ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)),
- ModelResponse(
- parts=[ToolCallPart(tool_name='final_result', args=ArgsDict(args_dict={'a': 0}))],
- timestamp=IsNow(tz=timezone.utc),
- ),
- ToolReturn(
- tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
- ),
- # second call, notice no repeated system prompt
- UserPrompt(content='Hello again', timestamp=IsNow(tz=timezone.utc)),
- ModelResponse(
- parts=[ToolCallPart(tool_name='final_result', args=ArgsDict(args_dict={'a': 0}))],
- timestamp=IsNow(tz=timezone.utc),
- ),
- ToolReturn(
- tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
+ ModelRequest(
+ parts=[
+ SystemPromptPart(content='Foobar'),
+ UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ],
),
- ],
- _new_message_index=6,
- _cost=Cost(),
- )
- )
- new_msg_kinds = [msg.message_kind for msg in result2.new_messages()]
- assert new_msg_kinds == snapshot(['user-prompt', 'model-response', 'tool-return'])
- assert result2.new_messages_json().startswith(b'[{"content":"Hello again",')
-
- # if we pass all_messages, system prompt is NOT inserted before the message_history messages,
- # so only one system prompt
- result3 = agent.run_sync('Hello again', message_history=result1.all_messages())
- # same as result2 except for datetimes
- assert result3 == snapshot(
- RunResult(
- data=Response(a=0),
- _all_messages=[
- SystemPrompt(content='Foobar'),
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
ModelResponse(
parts=[ToolCallPart(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))],
+ ),
ModelResponse(
parts=[ToolCallPart(tool_name='final_result', args=ArgsDict(args_dict={'a': 0}))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ timestamp=IsNow(tz=timezone.utc),
+ ),
+ ],
),
# second call, notice no repeated system prompt
- UserPrompt(content='Hello again', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ UserPromptPart(content='Hello again', timestamp=IsNow(tz=timezone.utc)),
+ ],
+ ),
ModelResponse(
parts=[ToolCallPart(tool_name='final_result', args=ArgsDict(args_dict={'a': 0}))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ timestamp=IsNow(tz=timezone.utc),
+ ),
+ ]
),
],
- _new_message_index=6,
+ _new_message_index=5,
_cost=Cost(),
)
)
+ new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()]
+ assert new_msg_part_kinds == snapshot(
+ [
+ ('request', ['system-prompt', 'user-prompt']),
+ ('response', ['tool-call']),
+ ('request', ['tool-return']),
+ ('response', ['tool-call']),
+ ('request', ['tool-return']),
+ ('request', ['user-prompt']),
+ ('response', ['tool-call']),
+ ('request', ['tool-return']),
+ ]
+ )
+ assert result2.new_messages_json().startswith(b'[{"parts":[{"content":"Hello again",')
def test_empty_tool_calls(set_event_loop: None):
- def empty(_: list[Message], _info: AgentInfo) -> ModelResponse:
+ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
return ModelResponse(parts=[])
agent = Agent(FunctionModel(empty))
@@ -620,7 +676,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelResponse:
def test_unknown_tool(set_event_loop: None):
- def empty(_: list[Message], _info: AgentInfo) -> ModelResponse:
+ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
return ModelResponse(parts=[ToolCallPart.from_json('foobar', '{}')])
agent = Agent(FunctionModel(empty))
@@ -629,12 +685,18 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelResponse:
agent.run_sync('Hello')
assert agent.last_run_messages == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart(tool_name='foobar', args=ArgsJson(args_json='{}'))],
timestamp=IsNow(tz=timezone.utc),
),
- RetryPrompt(content="Unknown tool name: 'foobar'. No tools available.", timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ content="Unknown tool name: 'foobar'. No tools available.", timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
+ ),
ModelResponse(
parts=[ToolCallPart(tool_name='foobar', args=ArgsJson(args_json='{}'))],
timestamp=IsNow(tz=timezone.utc),
@@ -644,7 +706,7 @@ def empty(_: list[Message], _info: AgentInfo) -> ModelResponse:
def test_unknown_tool_fix(set_event_loop: None):
- def empty(m: list[Message], _info: AgentInfo) -> ModelResponse:
+ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
if len(m) > 1:
return ModelResponse.from_text(content='success')
else:
@@ -656,12 +718,18 @@ def empty(m: list[Message], _info: AgentInfo) -> ModelResponse:
assert result.data == 'success'
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart(tool_name='foobar', args=ArgsJson(args_json='{}'))],
timestamp=IsNow(tz=timezone.utc),
),
- RetryPrompt(content="Unknown tool name: 'foobar'. No tools available.", timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ content="Unknown tool name: 'foobar'. No tools available.", timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
+ ),
ModelResponse.from_text(content='success', timestamp=IsNow(tz=timezone.utc)),
]
)
@@ -778,7 +846,7 @@ def test_early_strategy_stops_after_first_final_result(self):
"""Test that 'early' strategy stops processing regular tools after first final result."""
tool_called = []
- def return_model(_: list[Message], info: AgentInfo) -> ModelResponse:
+ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.result_tools is not None
return ModelResponse(
parts=[
@@ -809,18 +877,17 @@ def another_tool(y: int) -> int: # pragma: no cover
assert tool_called == []
# Verify we got tool returns for all calls
- tool_returns = [m for m in messages if isinstance(m, ToolReturn)]
- assert tool_returns == snapshot(
+ assert messages[-1].parts == snapshot(
[
- ToolReturn(
+ ToolReturnPart(
tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
),
- ToolReturn(
+ ToolReturnPart(
tool_name='regular_tool',
content='Tool not executed - a final result was already processed.',
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
+ ToolReturnPart(
tool_name='another_tool',
content='Tool not executed - a final result was already processed.',
timestamp=IsNow(tz=timezone.utc),
@@ -831,7 +898,7 @@ def another_tool(y: int) -> int: # pragma: no cover
def test_early_strategy_uses_first_final_result(self):
"""Test that 'early' strategy uses the first final result and ignores subsequent ones."""
- def return_model(_: list[Message], info: AgentInfo) -> ModelResponse:
+ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.result_tools is not None
return ModelResponse(
parts=[
@@ -847,13 +914,12 @@ def return_model(_: list[Message], info: AgentInfo) -> ModelResponse:
assert result.data.value == 'first'
# Verify we got appropriate tool returns
- tool_returns = [m for m in result.all_messages() if isinstance(m, ToolReturn)]
- assert tool_returns == snapshot(
+ assert result.new_messages()[-1].parts == snapshot(
[
- ToolReturn(
+ ToolReturnPart(
tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
),
- ToolReturn(
+ ToolReturnPart(
tool_name='final_result',
content='Result tool not used - a final result was already processed.',
timestamp=IsNow(tz=timezone.utc),
@@ -865,7 +931,7 @@ def test_exhaustive_strategy_executes_all_tools(self):
"""Test that 'exhaustive' strategy executes all tools while using first final result."""
tool_called: list[str] = []
- def return_model(_: list[Message], info: AgentInfo) -> ModelResponse:
+ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.result_tools is not None
return ModelResponse(
parts=[
@@ -902,7 +968,9 @@ def another_tool(y: int) -> int:
# Verify we got tool returns in the correct order
assert result.all_messages() == snapshot(
[
- UserPrompt(content='test exhaustive strategy', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[UserPromptPart(content='test exhaustive strategy', timestamp=IsNow(tz=timezone.utc))]
+ ),
ModelResponse(
parts=[
ToolCallPart(tool_name='regular_tool', args=ArgsDict(args_dict={'x': 42})),
@@ -913,20 +981,26 @@ def another_tool(y: int) -> int:
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
- ),
- ToolReturn(
- tool_name='final_result',
- content='Result tool not used - a final result was already processed.',
- timestamp=IsNow(tz=timezone.utc),
- ),
- RetryPrompt(
- content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ timestamp=IsNow(tz=timezone.utc),
+ ),
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Result tool not used - a final result was already processed.',
+ timestamp=IsNow(tz=timezone.utc),
+ ),
+ RetryPromptPart(
+ content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
+ timestamp=IsNow(tz=timezone.utc),
+ ),
+ ToolReturnPart(tool_name='regular_tool', content=42, timestamp=IsNow(tz=timezone.utc)),
+ ToolReturnPart(tool_name='another_tool', content=2, timestamp=IsNow(tz=timezone.utc)),
+ ]
),
- ToolReturn(tool_name='regular_tool', content=42, timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(tool_name='another_tool', content=2, timestamp=IsNow(tz=timezone.utc)),
]
)
@@ -934,7 +1008,7 @@ def test_early_strategy_with_final_result_in_middle(self):
"""Test that 'early' strategy stops at first final result, regardless of position."""
tool_called = []
- def return_model(_: list[Message], info: AgentInfo) -> ModelResponse:
+ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert info.result_tools is not None
return ModelResponse(
parts=[
@@ -967,7 +1041,13 @@ def another_tool(y: int) -> int: # pragma: no cover
# Verify we got appropriate tool returns
assert result.all_messages() == snapshot(
[
- UserPrompt(content='test early strategy with final result in middle', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[
+ UserPromptPart(
+ content='test early strategy with final result in middle', timestamp=IsNow(tz=timezone.utc)
+ )
+ ]
+ ),
ModelResponse(
parts=[
ToolCallPart(tool_name='regular_tool', args=ArgsDict(args_dict={'x': 1})),
@@ -977,22 +1057,28 @@ def another_tool(y: int) -> int: # pragma: no cover
],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(
- tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
- ),
- ToolReturn(
- tool_name='regular_tool',
- content='Tool not executed - a final result was already processed.',
- timestamp=IsNow(tz=timezone.utc),
- ),
- ToolReturn(
- tool_name='another_tool',
- content='Tool not executed - a final result was already processed.',
- timestamp=IsNow(tz=timezone.utc),
- ),
- RetryPrompt(
- content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ timestamp=IsNow(tz=timezone.utc),
+ ),
+ ToolReturnPart(
+ tool_name='regular_tool',
+ content='Tool not executed - a final result was already processed.',
+ timestamp=IsNow(tz=timezone.utc),
+ ),
+ ToolReturnPart(
+ tool_name='another_tool',
+ content='Tool not executed - a final result was already processed.',
+ timestamp=IsNow(tz=timezone.utc),
+ ),
+ RetryPromptPart(
+ content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
+ timestamp=IsNow(tz=timezone.utc),
+ ),
+ ]
),
]
)
@@ -1011,19 +1097,12 @@ def regular_tool(x: int) -> int:
result = agent.run_sync('test early strategy with regular tool calls')
assert tool_called == ['regular_tool']
- tool_returns = [m for m in result.all_messages() if isinstance(m, ToolReturn)]
- assert tool_returns == snapshot(
- [
- ToolReturn(tool_name='regular_tool', content=0, timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(
- tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
- ),
- ]
- )
+ tool_returns = [m for m in result.all_messages() if isinstance(m, ToolReturnPart)]
+ assert tool_returns == snapshot([])
async def test_model_settings_override() -> None:
- def return_settings(_: list[Message], info: AgentInfo) -> ModelResponse:
+ def return_settings(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
return ModelResponse.from_text(to_json(info.model_settings).decode())
my_agent = Agent(FunctionModel(return_settings))
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 7cb6a6ede2..64fca2f7d4 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -20,12 +20,12 @@
from pydantic_ai._utils import group_by_temporal
from pydantic_ai.messages import (
ArgsDict,
- Message,
+ ModelMessage,
ModelResponse,
- RetryPrompt,
+ RetryPromptPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from pydantic_ai.models import KnownModelName, Model
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
@@ -216,9 +216,9 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response:
}
-async def model_logic(messages: list[Message], info: AgentInfo) -> ModelResponse: # pragma: no cover
- m = messages[-1]
- if isinstance(m, UserPrompt):
+async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: # pragma: no cover
+ m = messages[-1].parts[-1]
+ if isinstance(m, UserPromptPart):
if response := text_responses.get(m.content):
if isinstance(response, str):
return ModelResponse.from_text(content=response)
@@ -228,28 +228,28 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelResponse
if re.fullmatch(r'sql prompt \d+', m.content):
return ModelResponse.from_text(content='SELECT 1')
- elif isinstance(m, ToolReturn) and m.tool_name == 'roulette_wheel':
+ elif isinstance(m, ToolReturnPart) and m.tool_name == 'roulette_wheel':
win = m.content == 'winner'
return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args=ArgsDict({'response': win}))])
- elif isinstance(m, ToolReturn) and m.tool_name == 'roll_die':
+ elif isinstance(m, ToolReturnPart) and m.tool_name == 'roll_die':
return ModelResponse(parts=[ToolCallPart(tool_name='get_player_name', args=ArgsDict({}))])
- elif isinstance(m, ToolReturn) and m.tool_name == 'get_player_name':
+ elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_player_name':
return ModelResponse.from_text(content="Congratulations Anne, you guessed correctly! You're a winner!")
if (
- isinstance(m, RetryPrompt)
+ isinstance(m, RetryPromptPart)
and isinstance(m.content, str)
and m.content.startswith("No user found with name 'Joh")
):
return ModelResponse(parts=[ToolCallPart(tool_name='get_user_by_name', args=ArgsDict({'name': 'John Doe'}))])
- elif isinstance(m, ToolReturn) and m.tool_name == 'get_user_by_name':
+ elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_user_by_name':
args = {
'message': 'Hello John, would you be free for coffee sometime next week? Let me know what works for you!',
'user_id': 123,
}
return ModelResponse(parts=[ToolCallPart(tool_name='final_result', args=ArgsDict(args))])
- elif isinstance(m, RetryPrompt) and m.tool_name == 'calc_volume':
+ elif isinstance(m, RetryPromptPart) and m.tool_name == 'calc_volume':
return ModelResponse(parts=[ToolCallPart(tool_name='calc_volume', args=ArgsDict({'size': 6}))])
- elif isinstance(m, ToolReturn) and m.tool_name == 'customer_balance':
+ elif isinstance(m, ToolReturnPart) and m.tool_name == 'customer_balance':
args = {
'support_advice': 'Hello John, your current account balance, including pending transactions, is $123.45.',
'block_card': False,
@@ -262,10 +262,10 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelResponse
async def stream_model_logic(
- messages: list[Message], info: AgentInfo
+ messages: list[ModelMessage], info: AgentInfo
) -> AsyncIterator[str | DeltaToolCalls]: # pragma: no cover
- m = messages[-1]
- if isinstance(m, UserPrompt):
+ m = messages[-1].parts[-1]
+ if isinstance(m, UserPromptPart):
if response := text_responses.get(m.content):
if isinstance(response, str):
words = response.split(' ')
diff --git a/tests/test_logfire.py b/tests/test_logfire.py
index 90ec0f40fc..2d1d1cdb68 100644
--- a/tests/test_logfire.py
+++ b/tests/test_logfire.py
@@ -76,14 +76,14 @@ async def my_ret(x: int) -> str:
'message': 'my_agent run prompt=Hello',
'children': [
{'id': 1, 'message': 'preparing model and tools run_step=1'},
- {'id': 2, 'message': 'model request -> model-response'},
+ {'id': 2, 'message': 'model request'},
{
'id': 3,
'message': 'handle model response -> tool-return',
'children': [{'id': 4, 'message': "running tools=['my_ret']"}],
},
{'id': 5, 'message': 'preparing model and tools run_step=2'},
- {'id': 6, 'message': 'model request -> model-response'},
+ {'id': 6, 'message': 'model request'},
{'id': 7, 'message': 'handle model response -> final result'},
],
}
@@ -121,10 +121,14 @@ async def my_ret(x: int) -> str:
'all_messages': IsJson(
[
{
- 'content': 'Hello',
- 'timestamp': IsStr(regex=r'\d{4}-\d{2}-.+'),
- 'role': 'user',
- 'message_kind': 'user-prompt',
+ 'parts': [
+ {
+ 'content': 'Hello',
+ 'timestamp': IsStr(regex=r'\d{4}-\d{2}-.+'),
+ 'part_kind': 'user-prompt',
+ },
+ ],
+ 'kind': 'request',
},
{
'parts': [
@@ -136,22 +140,24 @@ async def my_ret(x: int) -> str:
}
],
'timestamp': IsStr(regex=r'\d{4}-\d{2}-.+'),
- 'role': 'model',
- 'message_kind': 'model-response',
+ 'kind': 'response',
},
{
- 'tool_name': 'my_ret',
- 'content': '1',
- 'tool_call_id': None,
- 'timestamp': IsStr(regex=r'\d{4}-\d{2}-.+'),
- 'role': 'user',
- 'message_kind': 'tool-return',
+ 'parts': [
+ {
+ 'tool_name': 'my_ret',
+ 'content': '1',
+ 'tool_call_id': None,
+ 'timestamp': IsStr(regex=r'\d{4}-\d{2}-.+'),
+ 'part_kind': 'tool-return',
+ },
+ ],
+ 'kind': 'request',
},
{
'parts': [{'content': '{"my_ret":"1"}', 'part_kind': 'text'}],
'timestamp': IsStr(regex=r'\d{4}-\d{2}-.+'),
- 'role': 'model',
- 'message_kind': 'model-response',
+ 'kind': 'response',
},
]
),
@@ -177,9 +183,19 @@ async def my_ret(x: int) -> str:
'prefixItems': [
{
'type': 'object',
- 'title': 'UserPrompt',
+ 'title': 'ModelRequest',
'x-python-datatype': 'dataclass',
- 'properties': {'timestamp': {'type': 'string', 'format': 'date-time'}},
+ 'properties': {
+ 'parts': {
+ 'type': 'array',
+ 'items': {
+ 'type': 'object',
+ 'title': 'UserPromptPart',
+ 'x-python-datatype': 'dataclass',
+ 'properties': {'timestamp': {'type': 'string', 'format': 'date-time'}},
+ },
+ }
+ },
},
{
'type': 'object',
@@ -206,9 +222,19 @@ async def my_ret(x: int) -> str:
},
{
'type': 'object',
- 'title': 'ToolReturn',
+ 'title': 'ModelRequest',
'x-python-datatype': 'dataclass',
- 'properties': {'timestamp': {'type': 'string', 'format': 'date-time'}},
+ 'properties': {
+ 'parts': {
+ 'type': 'array',
+ 'items': {
+ 'type': 'object',
+ 'title': 'ToolReturnPart',
+ 'x-python-datatype': 'dataclass',
+ 'properties': {'timestamp': {'type': 'string', 'format': 'date-time'}},
+ },
+ }
+ },
},
{
'type': 'object',
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 6b76fc10c9..6dc06a007e 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -11,12 +11,13 @@
from pydantic_ai.messages import (
ArgsDict,
ArgsJson,
- Message,
+ ModelMessage,
+ ModelRequest,
ModelResponse,
- RetryPrompt,
+ RetryPromptPart,
ToolCallPart,
- ToolReturn,
- UserPrompt,
+ ToolReturnPart,
+ UserPromptPart,
)
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel
from pydantic_ai.models.test import TestModel
@@ -43,12 +44,14 @@ async def ret_a(x: str) -> str:
assert not result.is_complete
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
+ ),
]
)
response = await result.get_data()
@@ -58,12 +61,14 @@ async def ret_a(x: str) -> str:
assert result.timestamp() == IsNow(tz=timezone.utc)
assert result.all_messages() == snapshot(
[
- UserPrompt(content='Hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(
+ parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))]
+ ),
ModelResponse.from_text(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)),
]
)
@@ -84,7 +89,7 @@ async def test_streamed_structured_response():
async def test_structured_response_iter():
- async def text_stream(_messages: list[Message], agent_info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
+ async def text_stream(_messages: list[ModelMessage], agent_info: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
assert agent_info.result_tools is not None
assert len(agent_info.result_tools) == 1
name = agent_info.result_tools[0].name
@@ -148,7 +153,7 @@ async def test_streamed_text_stream():
async def test_plain_response():
call_index = 0
- async def text_stream(_messages: list[Message], _: AgentInfo) -> AsyncIterator[str]:
+ async def text_stream(_messages: list[ModelMessage], _: AgentInfo) -> AsyncIterator[str]:
nonlocal call_index
call_index += 1
@@ -166,25 +171,27 @@ async def text_stream(_messages: list[Message], _: AgentInfo) -> AsyncIterator[s
async def test_call_tool():
async def stream_structured_function(
- messages: list[Message], agent_info: AgentInfo
+ messages: list[ModelMessage], agent_info: AgentInfo
) -> AsyncIterator[DeltaToolCalls | str]:
if len(messages) == 1:
assert agent_info.function_tools is not None
assert len(agent_info.function_tools) == 1
name = agent_info.function_tools[0].name
first = messages[0]
- assert isinstance(first, UserPrompt)
- json_string = json.dumps({'x': first.content})
+ assert isinstance(first, ModelRequest)
+ assert isinstance(first.parts[0], UserPromptPart)
+ json_string = json.dumps({'x': first.parts[0].content})
yield {0: DeltaToolCall(name=name)}
yield {0: DeltaToolCall(json_args=json_string[:3])}
yield {0: DeltaToolCall(json_args=json_string[3:])}
else:
last = messages[-1]
- assert isinstance(last, ToolReturn)
+ assert isinstance(last, ModelRequest)
+ assert isinstance(last.parts[0], ToolReturnPart)
assert agent_info.result_tools is not None
assert len(agent_info.result_tools) == 1
name = agent_info.result_tools[0].name
- json_data = json.dumps({'response': [last.content, 2]})
+ json_data = json.dumps({'response': [last.parts[0].content, 2]})
yield {0: DeltaToolCall(name=name)}
yield {0: DeltaToolCall(json_args=json_data[:5])}
yield {0: DeltaToolCall(json_args=json_data[5:])}
@@ -199,32 +206,44 @@ async def ret_a(x: str) -> str:
async with agent.run_stream('hello') as result:
assert result.all_messages() == snapshot(
[
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart(tool_name='ret_a', args=ArgsJson(args_json='{"x": "hello"}'))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='ret_a', content='hello world', timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[ToolReturnPart(tool_name='ret_a', content='hello world', timestamp=IsNow(tz=timezone.utc))]
+ ),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
]
)
assert await result.get_data() == snapshot(('hello world', 2))
assert result.all_messages() == snapshot(
[
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart(tool_name='ret_a', args=ArgsJson(args_json='{"x": "hello"}'))],
timestamp=IsNow(tz=timezone.utc),
),
- ToolReturn(tool_name='ret_a', content='hello world', timestamp=IsNow(tz=timezone.utc)),
- ToolReturn(
- tool_name='final_result',
- content='Final result processed.',
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[ToolReturnPart(tool_name='ret_a', content='hello world', timestamp=IsNow(tz=timezone.utc))]
+ ),
+ ModelRequest(
+ parts=[
+ ToolReturnPart(
+ tool_name='final_result',
+ content='Final result processed.',
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
ModelResponse(
parts=[
@@ -240,7 +259,7 @@ async def ret_a(x: str) -> str:
async def test_call_tool_empty():
- async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
+ async def stream_structured_function(_messages: list[ModelMessage], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
yield {}
agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int])
@@ -251,7 +270,7 @@ async def stream_structured_function(_messages: list[Message], _: AgentInfo) ->
async def test_call_tool_wrong_name():
- async def stream_structured_function(_messages: list[Message], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
+ async def stream_structured_function(_messages: list[ModelMessage], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
yield {0: DeltaToolCall(name='foobar', json_args='{}')}
agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=tuple[str, int])
@@ -265,14 +284,18 @@ async def ret_a(x: str) -> str: # pragma: no cover
pass
assert agent.last_run_messages == snapshot(
[
- UserPrompt(content='hello', timestamp=IsNow(tz=timezone.utc)),
+ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[ToolCallPart(tool_name='foobar', args=ArgsJson(args_json='{}'))],
timestamp=IsNow(tz=timezone.utc),
),
- RetryPrompt(
- content="Unknown tool name: 'foobar'. Available tools: ret_a, final_result",
- timestamp=IsNow(tz=timezone.utc),
+ ModelRequest(
+ parts=[
+ RetryPromptPart(
+ content="Unknown tool name: 'foobar'. Available tools: ret_a, final_result",
+ timestamp=IsNow(tz=timezone.utc),
+ )
+ ]
),
]
)
diff --git a/tests/test_tools.py b/tests/test_tools.py
index f5ddb1c37c..7873495f6d 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -9,7 +9,7 @@
from pydantic_core import PydanticSerializationError
from pydantic_ai import Agent, RunContext, Tool, UserError
-from pydantic_ai.messages import Message, ModelResponse, ToolCallPart
+from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.test import TestModel
from pydantic_ai.tools import ToolDefinition
@@ -71,7 +71,7 @@ async def google_style_docstring(foo: int, bar: str) -> str: # pragma: no cover
return f'{foo} {bar}'
-async def get_json_schema(_messages: list[Message], info: AgentInfo) -> ModelResponse:
+async def get_json_schema(_messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
assert len(info.function_tools) == 1
r = info.function_tools[0]
return ModelResponse.from_text(pydantic_core.to_json(r).decode())
@@ -487,7 +487,7 @@ def foobar(ctx: RunContext[int], x: int, y: str) -> str:
def test_dynamic_tool_use_messages(set_event_loop: None):
- async def repeat_call_foobar(_messages: list[Message], info: AgentInfo) -> ModelResponse:
+ async def repeat_call_foobar(_messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
if info.function_tools:
tool = info.function_tools[0]
return ModelResponse.from_tool_call(ToolCallPart.from_dict(tool.name, {'x': 42, 'y': 'a'}))
@@ -506,15 +506,15 @@ def foobar(ctx: RunContext[int], x: int, y: str) -> str:
r = agent.run_sync('', deps=1)
assert r.data == snapshot('done')
- message_kinds = [m.message_kind for m in r.all_messages()]
- assert message_kinds == snapshot(
+ message_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in r.all_messages()]
+ assert message_part_kinds == snapshot(
[
- 'user-prompt',
- 'model-response',
- 'tool-return',
- 'model-response',
- 'tool-return',
- 'model-response',
+ ('request', ['user-prompt']),
+ ('response', ['tool-call']),
+ ('request', ['tool-return']),
+ ('response', ['tool-call']),
+ ('request', ['tool-return']),
+ ('response', ['text']),
]
)