Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ lint:

.PHONY: typecheck-pyright
typecheck-pyright:
uv run pyright
@# PYRIGHT_PYTHON_IGNORE_WARNINGS avoids the overhead of making a request to github on every invocation
PYRIGHT_PYTHON_IGNORE_WARNINGS=1 uv run pyright

.PHONY: typecheck-mypy
typecheck-mypy:
Expand Down
168 changes: 164 additions & 4 deletions docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ print(dice_result.data)
```

1. The simplest way to register tools via the `Agent` constructor is to pass a list of functions, the function signature is inspected to determine if the tool takes [`RunContext`][pydantic_ai.tools.RunContext].
2. `agent_a` and `agent_b` are identical — but we can use [`Tool`][pydantic_ai.tools.Tool] to give more fine-grained control over how tools are defined, e.g. setting their name or description.
2. `agent_a` and `agent_b` are identical — but we can use [`Tool`][pydantic_ai.tools.Tool] to reuse tool definitions and give more fine-grained control over how tools are defined, e.g. setting their name or description, or using a custom [`prepare`](#tool-prepare) method.

_(This example is complete, it can be run "as is")_

Expand Down Expand Up @@ -459,10 +459,10 @@ def foobar(a: int, b: str, c: dict[str, list[float]]) -> str:


def print_schema(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
tool = info.function_tools['foobar']
tool = info.function_tools[0]
print(tool.description)
#> Get me foobar.
print(tool.json_schema)
print(tool.parameters_json_schema)
"""
{
'description': 'Get me foobar.',
Expand Down Expand Up @@ -491,7 +491,167 @@ _(This example is complete, it can be run "as is")_

The return type of tool can be anything which Pydantic can serialize to JSON as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.

If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object. (TODO example)
If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object.

Here's an example, we use [`TestModel.agent_model_function_tools`][pydantic_ai.models.test.TestModel.agent_model_function_tools] to inspect the tool schema that would be passed to the model.

```py title="single_parameter_tool.py"
from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.models.test import TestModel

agent = Agent()


class Foobar(BaseModel):
"""This is a Foobar"""

x: int
y: str
z: float = 3.14


@agent.tool_plain
def foobar(f: Foobar) -> str:
return str(f)


test_model = TestModel()
result = agent.run_sync('hello', model=test_model)
print(result.data)
#> {"foobar":"x=0 y='a' z=3.14"}
print(test_model.agent_model_function_tools)
"""
[
ToolDefinition(
name='foobar',
description='',
parameters_json_schema={
'description': 'This is a Foobar',
'properties': {
'x': {'title': 'X', 'type': 'integer'},
'y': {'title': 'Y', 'type': 'string'},
'z': {'default': 3.14, 'title': 'Z', 'type': 'number'},
},
'required': ['x', 'y'],
'title': 'Foobar',
'type': 'object',
},
outer_typed_dict_key=None,
)
]
"""
```

_(This example is complete, it can be run "as is")_

### Dynamic Function tools {#tool-prepare}

Tools can optionally be defined with another function: `prepare`, which is called at each step of a run to
customize the definition of the tool passed to the model, or omit the tool completely from that step.

A `prepare` method can be registered via the `prepare` kwarg to any of the tool registration mechanisms:

* [`@agent.tool`][pydantic_ai.Agent.tool] decorator
* [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator
* [`Tool`][pydantic_ai.tools.Tool] dataclass

The `prepare` method, should be of type [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc], a function which takes [`RunContext`][pydantic_ai.tools.RunContext] and a pre-built [`ToolDefinition`][pydantic_ai.tools.ToolDefinition], and should either return that `ToolDefinition` with or without modifying it, return a new `ToolDefinition`, or return `None` to indicate this tools should not be registered for that step.

Here's a simple `prepare` method that only includes the tool if the value of the dependency is `42`.

As with the previous example, we use [`TestModel`][pydantic_ai.models.test.TestModel] to demonstrate the behavior without calling a real model.

```py title="tool_only_if_42.py"
from typing import Union

from pydantic_ai import Agent, RunContext
from pydantic_ai.tools import ToolDefinition

agent = Agent('test')


async def only_if_42(
ctx: RunContext[int], tool_def: ToolDefinition
) -> Union[ToolDefinition, None]:
if ctx.deps == 42:
return tool_def


@agent.tool(prepare=only_if_42)
def hitchhiker(ctx: RunContext[int], answer: str) -> str:
return f'{ctx.deps} {answer}'


result = agent.run_sync('testing...', deps=41)
print(result.data)
#> success (no tool calls)
result = agent.run_sync('testing...', deps=42)
print(result.data)
#> {"hitchhiker":"42 a"}
```

_(This example is complete, it can be run "as is")_

Here's a more complex example where we change the description of the `name` parameter to based on the value of `deps`

For the sake of variation, we create this tool using the [`Tool`][pydantic_ai.tools.Tool] dataclass.

```py title="customize_name.py"
from __future__ import annotations

from typing import Literal

from pydantic_ai import Agent, RunContext
from pydantic_ai.models.test import TestModel
from pydantic_ai.tools import Tool, ToolDefinition


def greet(name: str) -> str:
return f'hello {name}'


async def prepare_greet(
ctx: RunContext[Literal['human', 'machine']], tool_def: ToolDefinition
) -> ToolDefinition | None:
d = f'Name of the {ctx.deps} to greet.'
tool_def.parameters_json_schema['properties']['name']['description'] = d
return tool_def


greet_tool = Tool(greet, prepare=prepare_greet)
test_model = TestModel()
agent = Agent(test_model, tools=[greet_tool], deps_type=Literal['human', 'machine'])

result = agent.run_sync('testing...', deps='human')
print(result.data)
#> {"greet":"hello a"}
print(test_model.agent_model_function_tools)
"""
[
ToolDefinition(
name='greet',
description='',
parameters_json_schema={
'properties': {
'name': {
'title': 'Name',
'type': 'string',
'description': 'Name of the human to greet.',
}
},
'required': ['name'],
'type': 'object',
'additionalProperties': False,
},
outer_typed_dict_key=None,
)
]
"""
```

_(This example is complete, it can be run "as is")_

## Reflection and self-correction

Expand Down
2 changes: 0 additions & 2 deletions docs/testing-evals.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,6 @@ def call_weather_forecast( # (1)!
) -> ModelAnyResponse:
if len(messages) == 2:
# first call, call the weather forecast tool
assert set(info.function_tools.keys()) == {'weather_forecast'}

user_prompt = messages[1]
m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content)
assert m is not None
Expand Down
10 changes: 6 additions & 4 deletions pydantic_ai_slim/pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from pydantic_core import SchemaValidator, core_schema

from ._griffe import doc_descriptions
from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
from ._utils import check_object_json_schema, is_model_like

if TYPE_CHECKING:
pass
from .tools import ObjectJsonSchema


__all__ = 'function_schema', 'LazyTypeAdapter'
Expand Down Expand Up @@ -168,11 +168,13 @@ def takes_ctx(function: Callable[..., Any]) -> bool:
"""
sig = signature(function)
try:
_, first_param = next(iter(sig.parameters.items()))
first_param_name = next(iter(sig.parameters.keys()))
except StopIteration:
return False
else:
return first_param.annotation is not sig.empty and _is_call_ctx(first_param.annotation)
type_hints = _typing_extra.get_function_type_hints(function)
annotation = type_hints[first_param_name]
return annotation is not sig.empty and _is_call_ctx(annotation)


def _build_schema(
Expand Down
40 changes: 18 additions & 22 deletions pydantic_ai_slim/pydantic_ai/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .exceptions import ModelRetry
from .messages import ModelStructuredResponse, ToolCall
from .result import ResultData
from .tools import AgentDeps, ResultValidatorFunc, RunContext
from .tools import AgentDeps, ResultValidatorFunc, RunContext, ToolDefinition


@dataclass
Expand Down Expand Up @@ -94,10 +94,7 @@ def build(cls, response_type: type[ResultData], name: str, description: str | No
allow_text_result = False

def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData]:
return cast(
ResultTool[ResultData],
ResultTool.build(a, tool_name_, description, multiple), # pyright: ignore[reportUnknownMemberType]
)
return cast(ResultTool[ResultData], ResultTool(a, tool_name_, description, multiple))

tools: dict[str, ResultTool[ResultData]] = {}
if args := get_union_args(response_type):
Expand All @@ -121,38 +118,38 @@ def tool_names(self) -> list[str]:
"""Return the names of the tools."""
return list(self.tools.keys())

def tool_defs(self) -> list[ToolDefinition]:
"""Get tool definitions to register with the model."""
return [t.tool_def for t in self.tools.values()]


DEFAULT_DESCRIPTION = 'The final response which ends this conversation'


@dataclass
@dataclass(init=False)
class ResultTool(Generic[ResultData]):
name: str
description: str
tool_def: ToolDefinition
type_adapter: TypeAdapter[Any]
json_schema: _utils.ObjectJsonSchema
outer_typed_dict_key: str | None

@classmethod
def build(cls, response_type: type[ResultData], name: str, description: str | None, multiple: bool) -> Self | None:
def __init__(self, response_type: type[ResultData], name: str, description: str | None, multiple: bool):
"""Build a ResultTool dataclass from a response type."""
assert response_type is not str, 'ResultTool does not support str as a response type'

if _utils.is_model_like(response_type):
type_adapter = TypeAdapter(response_type)
self.type_adapter = TypeAdapter(response_type)
outer_typed_dict_key: str | None = None
# noinspection PyArgumentList
json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
else:
response_data_typed_dict = TypedDict('response_data_typed_dict', {'response': response_type}) # noqa
type_adapter = TypeAdapter(response_data_typed_dict)
self.type_adapter = TypeAdapter(response_data_typed_dict)
outer_typed_dict_key = 'response'
# noinspection PyArgumentList
json_schema = _utils.check_object_json_schema(type_adapter.json_schema())
parameters_json_schema = _utils.check_object_json_schema(self.type_adapter.json_schema())
# including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM
json_schema.pop('title')
parameters_json_schema.pop('title')

if json_schema_description := json_schema.pop('description', None):
if json_schema_description := parameters_json_schema.pop('description', None):
if description is None:
tool_description = json_schema_description
else:
Expand All @@ -162,11 +159,10 @@ def build(cls, response_type: type[ResultData], name: str, description: str | No
if multiple:
tool_description = f'{union_arg_name(response_type)}: {tool_description}'

return cls(
self.tool_def = ToolDefinition(
name=name,
description=tool_description,
type_adapter=type_adapter,
json_schema=json_schema,
parameters_json_schema=parameters_json_schema,
outer_typed_dict_key=outer_typed_dict_key,
)

Expand Down Expand Up @@ -204,7 +200,7 @@ def validate(
else:
raise
else:
if k := self.outer_typed_dict_key:
if k := self.tool_def.outer_typed_dict_key:
result = result[k]
return result

Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __post_init__(self):

async def run(self, deps: AgentDeps) -> str:
if self._takes_ctx:
args = (RunContext(deps, 0, None),)
args = (RunContext(deps, 0),)
else:
args = ()

Expand Down
17 changes: 11 additions & 6 deletions pydantic_ai_slim/pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
from datetime import datetime, timezone
from functools import partial
from types import GenericAlias
from typing import Any, Callable, Generic, TypeVar, Union, cast, overload
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast, overload

from pydantic import BaseModel
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, is_typeddict

if TYPE_CHECKING:
from .tools import ObjectJsonSchema

_P = ParamSpec('_P')
_R = TypeVar('_R')

Expand All @@ -39,10 +42,6 @@ def is_model_like(type_: Any) -> bool:
)


# With PEP-728 this should be a TypedDict with `type: Literal['object']`, and `extra_items=Any`
ObjectJsonSchema: TypeAlias = dict[str, Any]


def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
from .exceptions import UserError

Expand Down Expand Up @@ -127,6 +126,12 @@ def is_left(self) -> bool:
def whichever(self) -> Left | Right:
return self._left.value if self._left is not None else self.right

def __repr__(self):
if left := self._left:
return f'Either(left={left.value!r})'
else:
return f'Either(right={self.right!r})'


@asynccontextmanager
async def group_by_temporal(
Expand Down Expand Up @@ -218,7 +223,7 @@ async def async_iter_groups() -> AsyncIterator[list[T]]:

try:
yield async_iter_groups()
finally:
finally: # pragma: no cover
# after iteration if a tasks still exists, cancel it, this will only happen if an error occurred
if task:
task.cancel('Cancelling due to error in iterator')
Expand Down
Loading
Loading