Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ testcov: test
@echo "building coverage html"
@uv run coverage html

.PHONY: update-examples # update documentation examples
update-examples:
uv run -m pytest --update-examples

# `--no-strict` so you can build the docs without insiders packages
.PHONY: docs # Build the documentation
docs:
Expand Down
18 changes: 11 additions & 7 deletions docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,22 +308,26 @@ print(dice_result.all_messages())
),
ModelStructuredResponse(
calls=[
ToolCall(tool_name='roll_die', args=ArgsDict(args_dict={}), tool_id=None)
ToolCall(
tool_name='roll_die', args=ArgsDict(args_dict={}), tool_call_id=None
)
],
timestamp=datetime.datetime(...),
role='model-structured-response',
),
ToolReturn(
tool_name='roll_die',
content='4',
tool_id=None,
tool_call_id=None,
timestamp=datetime.datetime(...),
role='tool-return',
),
ModelStructuredResponse(
calls=[
ToolCall(
tool_name='get_player_name', args=ArgsDict(args_dict={}), tool_id=None
tool_name='get_player_name',
args=ArgsDict(args_dict={}),
tool_call_id=None,
)
],
timestamp=datetime.datetime(...),
Expand All @@ -332,7 +336,7 @@ print(dice_result.all_messages())
ToolReturn(
tool_name='get_player_name',
content='Anne',
tool_id=None,
tool_call_id=None,
timestamp=datetime.datetime(...),
role='tool-return',
),
Expand Down Expand Up @@ -747,7 +751,7 @@ except UnexpectedModelBehavior as e:
ToolCall(
tool_name='calc_volume',
args=ArgsDict(args_dict={'size': 6}),
tool_id=None,
tool_call_id=None,
)
],
timestamp=datetime.datetime(...),
Expand All @@ -756,7 +760,7 @@ except UnexpectedModelBehavior as e:
RetryPrompt(
content='Please try again.',
tool_name='calc_volume',
tool_id=None,
tool_call_id=None,
timestamp=datetime.datetime(...),
role='retry-prompt',
),
Expand All @@ -765,7 +769,7 @@ except UnexpectedModelBehavior as e:
ToolCall(
tool_name='calc_volume',
args=ArgsDict(args_dict={'size': 6}),
tool_id=None,
tool_call_id=None,
)
],
timestamp=datetime.datetime(...),
Expand Down
4 changes: 2 additions & 2 deletions docs/testing-evals.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ async def test_forecast():
'forecast_date': '2024-01-01', # (8)!
}
),
tool_id=None,
tool_call_id=None,
)
],
timestamp=IsNow(tz=timezone.utc),
Expand All @@ -153,7 +153,7 @@ async def test_forecast():
ToolReturn(
tool_name='weather_forecast',
content='Sunny with a chance of rain',
tool_id=None,
tool_call_id=None,
timestamp=IsNow(tz=timezone.utc),
role='tool-return',
),
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async def validate(
m = messages.RetryPrompt(content=r.message)
if tool_call is not None:
m.tool_name = tool_call.tool_name
m.tool_id = tool_call.tool_id
m.tool_call_id = tool_call.tool_call_id
raise ToolRetryError(m) from r
else:
return result_data
Expand Down Expand Up @@ -194,7 +194,7 @@ def validate(
m = messages.RetryPrompt(
tool_name=tool_call.tool_name,
content=e.errors(include_url=False),
tool_id=tool_call.tool_id,
tool_call_id=tool_call.tool_call_id,
)
raise ToolRetryError(m) from e
else:
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ async def _handle_model_response(
tool_return = _messages.ToolReturn(
tool_name=call.tool_name,
content='Final result processed.',
tool_id=call.tool_id,
tool_call_id=call.tool_call_id,
)
return _MarkFinalResult(result_data), [tool_return]

Expand Down Expand Up @@ -873,7 +873,7 @@ async def _handle_streamed_model_response(
tool_return = _messages.ToolReturn(
tool_name=call.tool_name,
content='Final result processed.',
tool_id=call.tool_id,
tool_call_id=call.tool_call_id,
)
return _MarkFinalResult(model_response), [tool_return]

Expand Down
16 changes: 8 additions & 8 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class ToolReturn:
"""The name of the "tool" was called."""
content: Any
"""The return value."""
tool_id: str | None = None
"""Optional tool identifier, this is used by some models including OpenAI."""
tool_call_id: str | None = None
"""Optional tool call identifier, this is used by some models including OpenAI."""
timestamp: datetime = field(default_factory=_now_utc)
"""The timestamp, when the tool returned."""
role: Literal['tool-return'] = 'tool-return'
Expand Down Expand Up @@ -100,8 +100,8 @@ class RetryPrompt:
"""
tool_name: str | None = None
"""The name of the tool that was called, if any."""
tool_id: str | None = None
"""The tool identifier, if any."""
tool_call_id: str | None = None
"""Optional tool call identifier, this is used by some models including OpenAI."""
timestamp: datetime = field(default_factory=_now_utc)
"""The timestamp, when the retry was triggered."""
role: Literal['retry-prompt'] = 'retry-prompt'
Expand Down Expand Up @@ -158,12 +158,12 @@ class ToolCall:

Either as JSON or a Python dictionary depending on how data was returned.
"""
tool_id: str | None = None
"""Optional tool identifier, this is used by some models including OpenAI."""
tool_call_id: str | None = None
"""Optional tool call identifier, this is used by some models including OpenAI."""

@classmethod
def from_json(cls, tool_name: str, args_json: str, tool_id: str | None = None) -> ToolCall:
return cls(tool_name, ArgsJson(args_json), tool_id)
def from_json(cls, tool_name: str, args_json: str, tool_call_id: str | None = None) -> ToolCall:
return cls(tool_name, ArgsJson(args_json), tool_call_id)

@classmethod
def from_dict(cls, tool_name: str, args_dict: dict[str, Any]) -> ToolCall:
Expand Down
14 changes: 7 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
# ToolReturn ->
return chat.ChatCompletionToolMessageParam(
role='tool',
tool_call_id=_guard_tool_id(message),
tool_call_id=_guard_tool_call_id(message),
content=message.model_response_str(),
)
elif message.role == 'retry-prompt':
Expand All @@ -256,7 +256,7 @@ def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
else:
return chat.ChatCompletionToolMessageParam(
role='tool',
tool_call_id=_guard_tool_id(message),
tool_call_id=_guard_tool_call_id(message),
content=message.model_response(),
)
elif message.role == 'model-text-response':
Expand Down Expand Up @@ -365,16 +365,16 @@ def timestamp(self) -> datetime:
return self._timestamp


def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
"""Type guard that checks a `tool_id` is not None both for static typing and runtime."""
assert t.tool_id is not None, f'Groq requires `tool_id` to be set: {t}'
return t.tool_id
def _guard_tool_call_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
"""Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
assert t.tool_call_id is not None, f'Groq requires `tool_call_id` to be set: {t}'
return t.tool_call_id


def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
return chat.ChatCompletionMessageToolCallParam(
id=_guard_tool_id(t),
id=_guard_tool_call_id(t),
type='function',
function={'name': t.tool_name, 'arguments': t.args.args_json},
)
Expand Down
14 changes: 7 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
# ToolReturn ->
return chat.ChatCompletionToolMessageParam(
role='tool',
tool_call_id=_guard_tool_id(message),
tool_call_id=_guard_tool_call_id(message),
content=message.model_response_str(),
)
elif message.role == 'retry-prompt':
Expand All @@ -244,7 +244,7 @@ def _map_message(message: Message) -> chat.ChatCompletionMessageParam:
else:
return chat.ChatCompletionToolMessageParam(
role='tool',
tool_call_id=_guard_tool_id(message),
tool_call_id=_guard_tool_call_id(message),
content=message.model_response(),
)
elif message.role == 'model-text-response':
Expand Down Expand Up @@ -351,16 +351,16 @@ def timestamp(self) -> datetime:
return self._timestamp


def _guard_tool_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
"""Type guard that checks a `tool_id` is not None both for static typing and runtime."""
assert t.tool_id is not None, f'OpenAI requires `tool_id` to be set: {t}'
return t.tool_id
def _guard_tool_call_id(t: ToolCall | ToolReturn | RetryPrompt) -> str:
"""Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
assert t.tool_call_id is not None, f'OpenAI requires `tool_call_id` to be set: {t}'
return t.tool_call_id


def _map_tool_call(t: ToolCall) -> chat.ChatCompletionMessageToolCallParam:
assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
return chat.ChatCompletionMessageToolCallParam(
id=_guard_tool_id(t),
id=_guard_tool_call_id(t),
type='function',
function={'name': t.tool_name, 'arguments': t.args.args_json},
)
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Mes
return messages.ToolReturn(
tool_name=message.tool_name,
content=response_content,
tool_id=message.tool_id,
tool_call_id=message.tool_call_id,
)

def _call_args(
Expand Down Expand Up @@ -289,7 +289,7 @@ def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.To
return messages.RetryPrompt(
tool_name=call_message.tool_name,
content=content,
tool_id=call_message.tool_id,
tool_call_id=call_message.tool_call_id,
)


Expand Down
14 changes: 7 additions & 7 deletions tests/models/test_groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,15 +187,15 @@ async def test_request_structured_response(allow_model_requests: None):
ToolCall(
tool_name='final_result',
args=ArgsJson(args_json='{"response": [1, 2, 123]}'),
tool_id='123',
tool_call_id='123',
)
],
timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
),
ToolReturn(
tool_name='final_result',
content='Final result processed.',
tool_id='123',
tool_call_id='123',
timestamp=IsNow(tz=timezone.utc),
),
]
Expand Down Expand Up @@ -264,31 +264,31 @@ async def get_location(loc_name: str) -> str:
ToolCall(
tool_name='get_location',
args=ArgsJson(args_json='{"loc_name": "San Fransisco"}'),
tool_id='1',
tool_call_id='1',
)
],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
RetryPrompt(
tool_name='get_location',
content='Wrong location, please try again',
tool_id='1',
tool_call_id='1',
timestamp=IsNow(tz=timezone.utc),
),
ModelStructuredResponse(
calls=[
ToolCall(
tool_name='get_location',
args=ArgsJson(args_json='{"loc_name": "London"}'),
tool_id='2',
tool_call_id='2',
)
],
timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
),
ToolReturn(
tool_name='get_location',
content='{"lat": 51, "lng": 0}',
tool_id='2',
tool_call_id='2',
timestamp=IsNow(tz=timezone.utc),
),
ModelTextResponse(content='final response', timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc)),
Expand Down Expand Up @@ -402,7 +402,7 @@ async def test_stream_structured(allow_model_requests: None):
ToolReturn(
tool_name='final_result',
content='Final result processed.',
tool_id=None,
tool_call_id=None,
timestamp=IsNow(tz=timezone.utc),
),
ModelStructuredResponse(
Expand Down
6 changes: 3 additions & 3 deletions tests/models/test_model_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_var_args(set_event_loop: None):
{
'tool_name': 'get_var_args',
'content': '{"args": [1, 2, 3]}',
'tool_id': None,
'tool_call_id': None,
'timestamp': IsStr() & IsNow(iso_string=True, tz=timezone.utc),
'role': 'tool-return',
}
Expand All @@ -225,8 +225,8 @@ def test_var_args(set_event_loop: None):
async def call_tool(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
if len(messages) == 1:
assert len(info.function_tools) == 1
tool_id = info.function_tools[0].name
return ModelStructuredResponse(calls=[ToolCall.from_json(tool_id, '{}')])
tool_call_id = info.function_tools[0].name
return ModelStructuredResponse(calls=[ToolCall.from_json(tool_call_id, '{}')])
else:
return ModelTextResponse('final response')

Expand Down
Loading
Loading