diff --git a/docs/agents.md b/docs/agents.md
index 5564cbc2c0..c75869a547 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -49,7 +49,7 @@ print(result.data)
```
1. Create an agent, which expects an integer dependency and returns a boolean result. This agent will have type `#!python Agent[int, bool]`.
-2. Define a tool that checks if the square is a winner. Here [`RunContext`][pydantic_ai.dependencies.RunContext] is parameterized with the dependency type `int`; if you got the dependency type wrong you'd get a typing error.
+2. Define a tool that checks if the square is a winner. Here [`RunContext`][pydantic_ai.tools.RunContext] is parameterized with the dependency type `int`; if you got the dependency type wrong you'd get a typing error.
3. In reality, you might want to use a random number here e.g. `random.randint(0, 36)`.
4. `result.data` will be a boolean indicating whether the square is a winner. Pydantic performs the result validation; it'll be typed as a `bool` since its type is derived from the `result_type` generic parameter of the agent.
@@ -161,7 +161,7 @@ def foobar(x: bytes) -> None:
pass
-result = agent.run_sync('Does their name start with "A"?', deps=User('Adam'))
+result = agent.run_sync('Does their name start with "A"?', deps=User('Anne'))
foobar(result.data) # (3)!
```
@@ -222,7 +222,7 @@ print(result.data)
1. The agent expects a string dependency.
2. Static system prompt defined at agent creation time.
-3. Dynamic system prompt defined via a decorator with [`RunContext`][pydantic_ai.dependencies.RunContext], this is called just after `run_sync`, not when the agent is created, so can benefit from runtime information like the dependencies used on that run.
+3. Dynamic system prompt defined via a decorator with [`RunContext`][pydantic_ai.tools.RunContext]; this is called just after `run_sync`, not when the agent is created, so it can benefit from runtime information like the dependencies used on that run.
4. Another dynamic system prompt; system prompts don't have to have the `RunContext` parameter.
_(This example is complete, it can be run "as is")_
@@ -238,12 +238,13 @@ They're useful when it is impractical or impossible to put all the context an ag
The main semantic difference between PydanticAI Tools and RAG is that RAG is synonymous with vector search, while PydanticAI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. See [#58](https://github.com/pydantic/pydantic-ai/issues/58))
-There are two different decorator functions to register tools:
+There are a number of ways to register tools with an agent:
-1. [`@agent.tool`][pydantic_ai.Agent.tool] — for tools that need access to the agent [context][pydantic_ai.dependencies.RunContext]
-2. [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] — for tools that do not need access to the agent [context][pydantic_ai.dependencies.RunContext]
+* via the [`@agent.tool`][pydantic_ai.Agent.tool] decorator — for tools that need access to the agent [context][pydantic_ai.tools.RunContext]
+* via the [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator — for tools that do not need access to the agent [context][pydantic_ai.tools.RunContext]
+* via the [`tools`][pydantic_ai.Agent.__init__] keyword argument to `Agent`, which can take either plain functions or instances of [`Tool`][pydantic_ai.tools.Tool]
-`@agent.tool` is the default since in the majority of cases tools will need access to the agent context.
+`@agent.tool` is considered the default decorator since in the majority of cases tools will need access to the agent context.
Here's an example using both:
@@ -275,9 +276,9 @@ def get_player_name(ctx: RunContext[str]) -> str:
return ctx.deps
-dice_result = agent.run_sync('My guess is 4', deps='Adam') # (5)!
+dice_result = agent.run_sync('My guess is 4', deps='Anne') # (5)!
print(dice_result.data)
-#> Congratulations Adam, you guessed correctly! You're a winner!
+#> Congratulations Anne, you guessed correctly! You're a winner!
```
1. This is a pretty simple task, so we can use the fast and cheap Gemini flash model.
@@ -330,13 +331,13 @@ print(dice_result.all_messages())
),
ToolReturn(
tool_name='get_player_name',
- content='Adam',
+ content='Anne',
tool_id=None,
timestamp=datetime.datetime(...),
role='tool-return',
),
ModelTextResponse(
- content="Congratulations Adam, you guessed correctly! You're a winner!",
+ content="Congratulations Anne, you guessed correctly! You're a winner!",
timestamp=datetime.datetime(...),
role='model-text-response',
),
@@ -370,16 +371,59 @@ sequenceDiagram
deactivate LLM
activate Agent
Note over Agent: Retrieves player name
-    Agent -->> LLM: ToolReturn<br>"Adam"
+    Agent -->> LLM: ToolReturn<br>"Anne"
deactivate Agent
activate LLM
Note over LLM: LLM constructs final response
-    LLM ->> Agent: ModelTextResponse<br>"Congratulations Adam, ..."
+    LLM ->> Agent: ModelTextResponse<br>"Congratulations Anne, ..."
deactivate LLM
Note over Agent: Game session complete
```
+### Registering Function Tools via kwarg
+
+As well as using the decorators, we can register tools via the `tools` argument to the [`Agent` constructor][pydantic_ai.Agent.__init__]. This is useful when you want to reuse tools, and it also gives you more fine-grained control over the tools.
+
+```py title="dice_game_tool_kwarg.py"
+import random
+
+from pydantic_ai import Agent, RunContext, Tool
+
+
+def roll_die() -> str:
+ """Roll a six-sided die and return the result."""
+ return str(random.randint(1, 6))
+
+
+def get_player_name(ctx: RunContext[str]) -> str:
+ """Get the player's name."""
+ return ctx.deps
+
+
+agent_a = Agent(
+ 'gemini-1.5-flash',
+ deps_type=str,
+ tools=[roll_die, get_player_name], # (1)!
+)
+agent_b = Agent(
+ 'gemini-1.5-flash',
+ deps_type=str,
+ tools=[ # (2)!
+ Tool(roll_die, takes_ctx=False),
+ Tool(get_player_name, takes_ctx=True),
+ ],
+)
+dice_result = agent_b.run_sync('My guess is 4', deps='Anne')
+print(dice_result.data)
+#> Congratulations Anne, you guessed correctly! You're a winner!
+```
+
+1. The simplest way to register tools via the `Agent` constructor is to pass a list of functions; the function signature is inspected to determine whether the tool takes [`RunContext`][pydantic_ai.tools.RunContext].
+2. `agent_a` and `agent_b` are identical — but we can use [`Tool`][pydantic_ai.tools.Tool] to give more fine-grained control over how tools are defined, e.g. setting their name or description (see the sketch below).
+
+_(This example is complete, it can be run "as is")_
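+
+If you want a name or description different from what would be inferred from the function, `Tool` lets you set them explicitly. Here's a minimal sketch (the overridden name, description and filename are illustrative):
+
+```py title="dice_game_tool_override.py"
+import random
+
+from pydantic_ai import Agent, Tool
+
+
+def roll_die() -> str:
+    """Roll a six-sided die and return the result."""
+    return str(random.randint(1, 6))
+
+
+agent = Agent(
+    'gemini-1.5-flash',
+    tools=[
+        Tool(
+            roll_die,
+            takes_ctx=False,
+            name='roll_dice',  # overrides the name inferred from the function
+            description='Roll a fair six-sided die.',  # overrides the docstring description
+        )
+    ],
+)
+```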
+
### Function Tools vs. Structured Results
As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call function tools while others end the run and return a result.
@@ -445,7 +489,7 @@ agent.run_sync('hello', model=FunctionModel(print_schema))
_(This example is complete, it can be run "as is")_
-The return type of tool can be any valid JSON object ([`JsonData`][pydantic_ai.dependencies.JsonData]) as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
+The return type of a tool can be any valid JSON object ([`JsonData`][pydantic_ai.tools.JsonData]): some models (e.g. Gemini) support semi-structured return values, while others expect text (e.g. OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
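+
+For example, a tool can return a plain Python mapping and leave serialization to PydanticAI. Here's a minimal sketch (the model name, filename and returned data are illustrative):
+
+```py title="tool_json_return.py"
+from pydantic_ai import Agent
+
+agent = Agent('gemini-1.5-flash')
+
+
+@agent.tool_plain
+def capital_info() -> dict[str, str]:
+    """Return semi-structured information about a capital city."""
+    # The returned mapping is valid `JsonData`; if the model expects a plain
+    # text tool result, it will be serialized to JSON automatically.
+    return {'country': 'France', 'capital': 'Paris'}
+```
+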
If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object. (TODO example)
@@ -456,7 +500,7 @@ Validation errors from both function tool parameter validation and [structured r
You can also raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] from within a [tool](#function-tools) or [result validator function](results.md#result-validators-functions) to tell the model it should retry generating a response.
- The default retry count is **1** but can be altered for the [entire agent][pydantic_ai.Agent.__init__], a [specific tool][pydantic_ai.Agent.tool], or a [result validator][pydantic_ai.Agent.__init__].
-- You can access the current retry count from within a tool or result validator via [`ctx.retry`][pydantic_ai.dependencies.RunContext].
+- You can access the current retry count from within a tool or result validator via [`ctx.retry`][pydantic_ai.tools.RunContext].
Here's an example:
diff --git a/docs/api/dependencies.md b/docs/api/dependencies.md
deleted file mode 100644
index 9f49436a0a..0000000000
--- a/docs/api/dependencies.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# `pydantic_ai.dependencies`
-
-::: pydantic_ai.dependencies
diff --git a/docs/api/tools.md b/docs/api/tools.md
new file mode 100644
index 0000000000..3dcf88d566
--- /dev/null
+++ b/docs/api/tools.md
@@ -0,0 +1,3 @@
+# `pydantic_ai.tools`
+
+::: pydantic_ai.tools
diff --git a/docs/dependencies.md b/docs/dependencies.md
index 129b58228a..b088ba94bc 100644
--- a/docs/dependencies.md
+++ b/docs/dependencies.md
@@ -51,7 +51,7 @@ _(This example is complete, it can be run "as is")_
## Accessing Dependencies
-Dependencies are accessed through the [`RunContext`][pydantic_ai.dependencies.RunContext] type, this should be the first parameter of system prompt functions etc.
+Dependencies are accessed through the [`RunContext`][pydantic_ai.tools.RunContext] type; this should be the first parameter of system prompt functions etc.
```py title="system_prompt_dependencies.py" hl_lines="20-27"
@@ -92,10 +92,10 @@ async def main():
#> Did you hear about the toothpaste scandal? They called it Colgate.
```
-1. [`RunContext`][pydantic_ai.dependencies.RunContext] may optionally be passed to a [`system_prompt`][pydantic_ai.Agent.system_prompt] function as the only argument.
-2. [`RunContext`][pydantic_ai.dependencies.RunContext] is parameterized with the type of the dependencies, if this type is incorrect, static type checkers will raise an error.
-3. Access dependencies through the [`.deps`][pydantic_ai.dependencies.RunContext.deps] attribute.
-4. Access dependencies through the [`.deps`][pydantic_ai.dependencies.RunContext.deps] attribute.
+1. [`RunContext`][pydantic_ai.tools.RunContext] may optionally be passed to a [`system_prompt`][pydantic_ai.Agent.system_prompt] function as the only argument.
+2. [`RunContext`][pydantic_ai.tools.RunContext] is parameterized with the type of the dependencies; if this type is incorrect, static type checkers will raise an error.
+3. Access dependencies through the [`.deps`][pydantic_ai.tools.RunContext.deps] attribute.
+4. Access dependencies through the [`.deps`][pydantic_ai.tools.RunContext.deps] attribute.
_(This example is complete, it can be run "as is")_
diff --git a/docs/index.md b/docs/index.md
index f6b233b71a..69d7e14f52 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -125,8 +125,8 @@ async def main():
2. Here we configure the agent to use [OpenAI's GPT-4o model](api/models/openai.md); you can also set the model when running the agent.
3. The `SupportDependencies` dataclass is used to pass data, connections, and logic into the model that will be needed when running [system prompt](agents.md#system-prompts) and [tool](agents.md#function-tools) functions. PydanticAI's system of dependency injection provides a [type-safe](agents.md#static-type-checking) way to customise the behavior of your agents, and can be especially useful when running [unit tests](testing-evals.md) and evals.
4. Static [system prompts](agents.md#system-prompts) can be registered with the [`system_prompt` keyword argument][pydantic_ai.Agent.__init__] to the agent.
-5. Dynamic [system prompts](agents.md#system-prompts) can be registered with the [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt] decorator, and can make use of dependency injection. Dependencies are carried via the [`RunContext`][pydantic_ai.dependencies.RunContext] argument, which is parameterized with the `deps_type` from above. If the type annotation here is wrong, static type checkers will catch it.
-6. [`tool`](agents.md#function-tools) let you register functions which the LLM may call while responding to a user. Again, dependencies are carried via [`RunContext`][pydantic_ai.dependencies.RunContext], any other arguments become the tool schema passed to the LLM. Pydantic is used to validate these arguments, and errors are passed back to the LLM so it can retry.
+5. Dynamic [system prompts](agents.md#system-prompts) can be registered with the [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt] decorator, and can make use of dependency injection. Dependencies are carried via the [`RunContext`][pydantic_ai.tools.RunContext] argument, which is parameterized with the `deps_type` from above. If the type annotation here is wrong, static type checkers will catch it.
+6. [`tool`](agents.md#function-tools) lets you register functions which the LLM may call while responding to a user. Again, dependencies are carried via [`RunContext`][pydantic_ai.tools.RunContext]; any other arguments become the tool schema passed to the LLM. Pydantic is used to validate these arguments, and errors are passed back to the LLM so it can retry.
7. The docstring of a tool is also passed to the LLM as the description of the tool. Parameter descriptions are [extracted](agents.md#function-tools-and-schema) from the docstring and added to the parameter schema sent to the LLM.
8. [Run the agent](agents.md#running-agents) asynchronously, conducting a conversation with the LLM until a final response is reached. Even in this fairly simple case, the agent will exchange multiple messages with the LLM as tools are called to retrieve a result.
9. The response from the agent will be guaranteed to be a `SupportResult`; if validation fails [reflection](agents.md#reflection-and-self-correction) will mean the agent is prompted to try again.
diff --git a/docs/install.md b/docs/install.md
index f22f3dd9fa..9d1b5ad268 100644
--- a/docs/install.md
+++ b/docs/install.md
@@ -40,7 +40,7 @@ To run the examples, follow instructions in the [examples docs](examples/index.m
## Slim Install
-If you know which model you're going to use and want to avoid installing superfluous package, you can use the [`pydantic-ai-slim`](https://pypi.org/project/pydantic-ai-slim/) package.
+If you know which model you're going to use and want to avoid installing superfluous packages, you can use the [`pydantic-ai-slim`](https://pypi.org/project/pydantic-ai-slim/) package.
If you're using just [`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel], run:
diff --git a/mkdocs.yml b/mkdocs.yml
index 589b1d2604..d46b774a93 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -31,9 +31,9 @@ nav:
- examples/chat-app.md
- API Reference:
- api/agent.md
+ - api/tools.md
- api/result.md
- api/messages.md
- - api/dependencies.md
- api/exceptions.md
- api/models/base.md
- api/models/openai.md
diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py
index c4591d2a17..119699afc6 100644
--- a/pydantic_ai_slim/pydantic_ai/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/__init__.py
@@ -1,8 +1,8 @@
from importlib.metadata import version
from .agent import Agent
-from .dependencies import RunContext
from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
+from .tools import RunContext, Tool
-__all__ = 'Agent', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
+__all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__'
__version__ = version('pydantic_ai_slim')
diff --git a/pydantic_ai_slim/pydantic_ai/_pydantic.py b/pydantic_ai_slim/pydantic_ai/_pydantic.py
index a948acf051..4aecd5fcf1 100644
--- a/pydantic_ai_slim/pydantic_ai/_pydantic.py
+++ b/pydantic_ai_slim/pydantic_ai/_pydantic.py
@@ -6,7 +6,7 @@
from __future__ import annotations as _annotations
from inspect import Parameter, signature
-from typing import TYPE_CHECKING, Any, TypedDict, cast, get_origin
+from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin
from pydantic import ConfigDict, TypeAdapter
from pydantic._internal import _decorators, _generate_schema, _typing_extra
@@ -20,8 +20,7 @@
from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like
if TYPE_CHECKING:
- from . import _tool
- from .dependencies import AgentDeps, ToolParams
+ pass
__all__ = 'function_schema', 'LazyTypeAdapter'
@@ -39,17 +38,16 @@ class FunctionSchema(TypedDict):
var_positional_field: str | None
-def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]) -> FunctionSchema: # noqa: C901
+def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSchema: # noqa: C901
"""Build a Pydantic validator and JSON schema from a tool function.
Args:
- either_function: The function to build a validator and JSON schema for.
+ function: The function to build a validator and JSON schema for.
+ takes_ctx: Whether the function takes a `RunContext` first argument.
Returns:
A `FunctionSchema` instance.
"""
- function = either_function.whichever()
- takes_ctx = either_function.is_left()
config = ConfigDict(title=function.__name__)
config_wrapper = ConfigWrapper(config)
gen_schema = _generate_schema.GenerateSchema(config_wrapper)
@@ -78,13 +76,13 @@ def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]
if index == 0 and takes_ctx:
if not _is_call_ctx(annotation):
- errors.append('First argument must be a RunContext instance when using `.tool`')
+ errors.append('First parameter of tools that take context must be annotated with RunContext[...]')
continue
elif not takes_ctx and _is_call_ctx(annotation):
- errors.append('RunContext instance can only be used with `.tool`')
+ errors.append('RunContext annotations can only be used with tools that take context')
continue
elif index != 0 and _is_call_ctx(annotation):
- errors.append('RunContext instance can only be used as the first argument')
+ errors.append('RunContext annotations can only be used as the first argument')
continue
field_name = p.name
@@ -159,6 +157,24 @@ def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]
)
+def takes_ctx(function: Callable[..., Any]) -> bool:
+ """Check if a function takes a `RunContext` first argument.
+
+ Args:
+ function: The function to check.
+
+ Returns:
+ `True` if the function takes a `RunContext` as first argument, `False` otherwise.
+ """
+ sig = signature(function)
+ try:
+ _, first_param = next(iter(sig.parameters.items()))
+ except StopIteration:
+ return False
+ else:
+ return first_param.annotation is not sig.empty and _is_call_ctx(first_param.annotation)
+
+
def _build_schema(
fields: dict[str, core_schema.TypedDictField],
var_kwargs_schema: core_schema.CoreSchema | None,
@@ -191,7 +207,7 @@ def _build_schema(
def _is_call_ctx(annotation: Any) -> bool:
- from .dependencies import RunContext
+ from .tools import RunContext
return annotation is RunContext or (
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
diff --git a/pydantic_ai_slim/pydantic_ai/_result.py b/pydantic_ai_slim/pydantic_ai/_result.py
index cbd2880810..3622f8664e 100644
--- a/pydantic_ai_slim/pydantic_ai/_result.py
+++ b/pydantic_ai_slim/pydantic_ai/_result.py
@@ -11,10 +11,10 @@
from typing_extensions import Self, TypeAliasType, TypedDict
from . import _utils, messages
-from .dependencies import AgentDeps, ResultValidatorFunc, RunContext
from .exceptions import ModelRetry
from .messages import ModelStructuredResponse, ToolCall
from .result import ResultData
+from .tools import AgentDeps, ResultValidatorFunc, RunContext
@dataclass
diff --git a/pydantic_ai_slim/pydantic_ai/_system_prompt.py b/pydantic_ai_slim/pydantic_ai/_system_prompt.py
index 04096c4e02..7f05079a13 100644
--- a/pydantic_ai_slim/pydantic_ai/_system_prompt.py
+++ b/pydantic_ai_slim/pydantic_ai/_system_prompt.py
@@ -6,7 +6,7 @@
from typing import Any, Callable, Generic, cast
from . import _utils
-from .dependencies import AgentDeps, RunContext, SystemPromptFunc
+from .tools import AgentDeps, RunContext, SystemPromptFunc
@dataclass
diff --git a/pydantic_ai_slim/pydantic_ai/_tool.py b/pydantic_ai_slim/pydantic_ai/_tool.py
deleted file mode 100644
index 9d82360bde..0000000000
--- a/pydantic_ai_slim/pydantic_ai/_tool.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from __future__ import annotations as _annotations
-
-import inspect
-from collections.abc import Awaitable
-from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, cast
-
-from pydantic import ValidationError
-from pydantic_core import SchemaValidator
-
-from . import _pydantic, _utils, messages
-from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
-from .exceptions import ModelRetry, UnexpectedModelBehavior
-
-# Usage `ToolEitherFunc[AgentDependencies, P]`
-ToolEitherFunc = _utils.Either[ToolContextFunc[AgentDeps, ToolParams], ToolPlainFunc[ToolParams]]
-
-
-@dataclass(init=False)
-class Tool(Generic[AgentDeps, ToolParams]):
- """A tool function for an agent."""
-
- name: str
- description: str
- function: ToolEitherFunc[AgentDeps, ToolParams] = field(repr=False)
- is_async: bool
- single_arg_name: str | None
- positional_fields: list[str]
- var_positional_field: str | None
- validator: SchemaValidator = field(repr=False)
- json_schema: _utils.ObjectJsonSchema
- max_retries: int
- _current_retry: int = 0
- outer_typed_dict_key: str | None = None
-
- def __init__(self, function: ToolEitherFunc[AgentDeps, ToolParams], retries: int):
- """Build a Tool dataclass from a function."""
- self.function = function
- # noinspection PyTypeChecker
- f = _pydantic.function_schema(function)
- raw_function = function.whichever()
- self.name = raw_function.__name__
- self.description = f['description']
- self.is_async = inspect.iscoroutinefunction(raw_function)
- self.single_arg_name = f['single_arg_name']
- self.positional_fields = f['positional_fields']
- self.var_positional_field = f['var_positional_field']
- self.validator = f['validator']
- self.json_schema = f['json_schema']
- self.max_retries = retries
-
- def reset(self) -> None:
- """Reset the current retry count."""
- self._current_retry = 0
-
- async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
- """Run the tool function asynchronously."""
- try:
- if isinstance(message.args, messages.ArgsJson):
- args_dict = self.validator.validate_json(message.args.args_json)
- else:
- args_dict = self.validator.validate_python(message.args.args_dict)
- except ValidationError as e:
- return self._on_error(e, message)
-
- args, kwargs = self._call_args(deps, args_dict, message)
- try:
- if self.is_async:
- function = cast(Callable[[Any], Awaitable[str]], self.function.whichever())
- response_content = await function(*args, **kwargs)
- else:
- function = cast(Callable[[Any], str], self.function.whichever())
- response_content = await _utils.run_in_executor(function, *args, **kwargs)
- except ModelRetry as e:
- return self._on_error(e, message)
-
- self._current_retry = 0
- return messages.ToolReturn(
- tool_name=message.tool_name,
- content=response_content,
- tool_id=message.tool_id,
- )
-
- def _call_args(
- self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
- ) -> tuple[list[Any], dict[str, Any]]:
- if self.single_arg_name:
- args_dict = {self.single_arg_name: args_dict}
-
- args = [RunContext(deps, self._current_retry, message.tool_name)] if self.function.is_left() else []
- for positional_field in self.positional_fields:
- args.append(args_dict.pop(positional_field))
- if self.var_positional_field:
- args.extend(args_dict.pop(self.var_positional_field))
-
- return args, args_dict
-
- def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
- self._current_retry += 1
- if self._current_retry > self.max_retries:
- # TODO custom error with details of the tool
- raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
- else:
- if isinstance(exc, ValidationError):
- content = exc.errors(include_url=False)
- else:
- content = exc.message
- return messages.RetryPrompt(
- tool_name=call_message.tool_name,
- content=content,
- tool_id=call_message.tool_id,
- )
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index 13cb4c4959..5f8e2a8d78 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -1,6 +1,7 @@
from __future__ import annotations as _annotations
import asyncio
+import dataclasses
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass, field
@@ -12,15 +13,14 @@
from . import (
_result,
_system_prompt,
- _tool as _r,
_utils,
exceptions,
messages as _messages,
models,
result,
)
-from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc
from .result import ResultData
+from .tools import AgentDeps, RunContext, Tool, ToolFuncContext, ToolFuncEither, ToolFuncPlain, ToolParams
__all__ = ('Agent',)
@@ -34,7 +34,7 @@
class Agent(Generic[AgentDeps, ResultData]):
"""Class for defining "agents" - a way to have a specific type of "conversation" with an LLM.
- Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.dependencies.AgentDeps]
+ Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.tools.AgentDeps]
and the result data type they return, [`ResultData`][pydantic_ai.result.ResultData].
By default, if neither generic parameter is customised, agents have type `Agent[None, str]`.
@@ -58,7 +58,7 @@ class Agent(Generic[AgentDeps, ResultData]):
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False)
_allow_text_result: bool = field(repr=False)
_system_prompts: tuple[str, ...] = field(repr=False)
- _function_tools: dict[str, _r.Tool[AgentDeps, Any]] = field(repr=False)
+ _function_tools: dict[str, Tool[AgentDeps]] = field(repr=False)
_default_retries: int = field(repr=False)
_system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False)
_deps_type: type[AgentDeps] = field(repr=False)
@@ -83,6 +83,7 @@ def __init__(
result_tool_name: str = 'final_result',
result_tool_description: str | None = None,
result_retries: int | None = None,
+ tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (),
defer_model_check: bool = False,
):
"""Create an agent.
@@ -101,6 +102,8 @@ def __init__(
result_tool_name: The name of the tool to use for the final result.
result_tool_description: The description of the final result tool.
result_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
+            tools: Tools to register with the agent; you can also register tools via the decorators
+ [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately,
which checks for the necessary environment variables. Set this to `false`
@@ -119,9 +122,11 @@ def __init__(
self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
- self._function_tools: dict[str, _r.Tool[AgentDeps, Any]] = {}
- self._deps_type = deps_type
+ self._function_tools = {}
self._default_retries = retries
+ for tool in tools:
+ self._register_tool(Tool.infer(tool))
+ self._deps_type = deps_type
self._system_prompt_functions = []
self._max_result_retries = result_retries if result_retries is not None else retries
self._current_result_retry = 0
@@ -355,7 +360,7 @@ def system_prompt(
) -> _system_prompt.SystemPromptFunc[AgentDeps]:
"""Decorator to register a system prompt function.
- Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's only argument.
+        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its only argument.
    Can decorate sync or async functions.
Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure
@@ -406,7 +411,7 @@ def result_validator(
) -> _result.ResultValidatorFunc[AgentDeps, ResultData]:
"""Decorator to register a result validator function.
- Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's first argument.
+        Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
    Can decorate sync or async functions.
Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure
@@ -439,22 +444,22 @@ async def result_validator_deps(ctx: RunContext[str], data: str) -> str:
return func
@overload
- def tool(self, func: ToolContextFunc[AgentDeps, ToolParams], /) -> ToolContextFunc[AgentDeps, ToolParams]: ...
+ def tool(self, func: ToolFuncContext[AgentDeps, ToolParams], /) -> ToolFuncContext[AgentDeps, ToolParams]: ...
@overload
def tool(
self, /, *, retries: int | None = None
- ) -> Callable[[ToolContextFunc[AgentDeps, ToolParams]], ToolContextFunc[AgentDeps, ToolParams]]: ...
+ ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ...
def tool(
self,
- func: ToolContextFunc[AgentDeps, ToolParams] | None = None,
+ func: ToolFuncContext[AgentDeps, ToolParams] | None = None,
/,
*,
retries: int | None = None,
) -> Any:
"""Decorator to register a tool function which takes
- [`RunContext`][pydantic_ai.dependencies.RunContext] as its first argument.
+ [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
    Can decorate sync or async functions.
@@ -491,27 +496,27 @@ async def spam(ctx: RunContext[str], y: float) -> float:
if func is None:
def tool_decorator(
- func_: ToolContextFunc[AgentDeps, ToolParams],
- ) -> ToolContextFunc[AgentDeps, ToolParams]:
+ func_: ToolFuncContext[AgentDeps, ToolParams],
+ ) -> ToolFuncContext[AgentDeps, ToolParams]:
# noinspection PyTypeChecker
- self._register_tool(_utils.Either(left=func_), retries)
+ self._register_function(func_, True, retries)
return func_
return tool_decorator
else:
# noinspection PyTypeChecker
- self._register_tool(_utils.Either(left=func), retries)
+ self._register_function(func, True, retries)
return func
@overload
- def tool_plain(self, func: ToolPlainFunc[ToolParams], /) -> ToolPlainFunc[ToolParams]: ...
+ def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ...
@overload
def tool_plain(
self, /, *, retries: int | None = None
- ) -> Callable[[ToolPlainFunc[ToolParams]], ToolPlainFunc[ToolParams]]: ...
+ ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
- def tool_plain(self, func: ToolPlainFunc[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
+ def tool_plain(self, func: ToolFuncPlain[ToolParams] | None = None, /, *, retries: int | None = None) -> Any:
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
    Can decorate sync or async functions.
@@ -548,28 +553,34 @@ async def spam(ctx: RunContext[str]) -> float:
"""
if func is None:
- def tool_decorator(
- func_: ToolPlainFunc[ToolParams],
- ) -> ToolPlainFunc[ToolParams]:
+ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]:
# noinspection PyTypeChecker
- self._register_tool(_utils.Either(right=func_), retries)
+ self._register_function(func_, False, retries)
return func_
return tool_decorator
else:
- self._register_tool(_utils.Either(right=func), retries)
+ self._register_function(func, False, retries)
return func
- def _register_tool(self, func: _r.ToolEitherFunc[AgentDeps, ToolParams], retries: int | None) -> None:
- """Private utility to register a tool function."""
+ def _register_function(
+ self, func: ToolFuncEither[AgentDeps, ToolParams], takes_ctx: bool, retries: int | None
+ ) -> None:
+ """Private utility to register a function as a tool."""
retries_ = retries if retries is not None else self._default_retries
- tool = _r.Tool[AgentDeps, ToolParams](func, retries_)
+ tool = Tool(func, takes_ctx, max_retries=retries_)
+ self._register_tool(tool)
- if self._result_schema and tool.name in self._result_schema.tools:
- raise ValueError(f'Tool name conflicts with result schema name: {tool.name!r}')
+ def _register_tool(self, tool: Tool[AgentDeps]) -> None:
+ """Private utility to register a tool instance."""
+ if tool.max_retries is None:
+ tool = dataclasses.replace(tool, max_retries=self._default_retries)
if tool.name in self._function_tools:
- raise ValueError(f'Tool name conflicts with existing tool: {tool.name!r}')
+ raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
+
+ if self._result_schema and tool.name in self._result_schema.tools:
+ raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}')
self._function_tools[tool.name] = tool
diff --git a/pydantic_ai_slim/pydantic_ai/dependencies.py b/pydantic_ai_slim/pydantic_ai/dependencies.py
deleted file mode 100644
index 1f0d300689..0000000000
--- a/pydantic_ai_slim/pydantic_ai/dependencies.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from __future__ import annotations as _annotations
-
-from collections.abc import Awaitable, Mapping, Sequence
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
-
-from typing_extensions import Concatenate, ParamSpec, TypeAlias
-
-if TYPE_CHECKING:
- from .result import ResultData
-else:
- ResultData = Any
-
-__all__ = (
- 'AgentDeps',
- 'RunContext',
- 'ResultValidatorFunc',
- 'SystemPromptFunc',
- 'ToolReturnValue',
- 'ToolContextFunc',
- 'ToolPlainFunc',
- 'ToolParams',
- 'JsonData',
-)
-
-AgentDeps = TypeVar('AgentDeps')
-"""Type variable for agent dependencies."""
-
-
-@dataclass
-class RunContext(Generic[AgentDeps]):
- """Information about the current call."""
-
- deps: AgentDeps
- """Dependencies for the agent."""
- retry: int
- """Number of retries so far."""
- tool_name: str | None
- """Name of the tool being called."""
-
-
-ToolParams = ParamSpec('ToolParams')
-"""Retrieval function param spec."""
-
-SystemPromptFunc = Union[
- Callable[[RunContext[AgentDeps]], str],
- Callable[[RunContext[AgentDeps]], Awaitable[str]],
- Callable[[], str],
- Callable[[], Awaitable[str]],
-]
-"""A function that may or maybe not take `RunContext` as an argument, and may or may not be async.
-
-Usage `SystemPromptFunc[AgentDeps]`.
-"""
-
-ResultValidatorFunc = Union[
- Callable[[RunContext[AgentDeps], ResultData], ResultData],
- Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
- Callable[[ResultData], ResultData],
- Callable[[ResultData], Awaitable[ResultData]],
-]
-"""
-A function that always takes `ResultData` and returns `ResultData`,
-but may or maybe not take `CallInfo` as a first argument, and may or may not be async.
-
-Usage `ResultValidator[AgentDeps, ResultData]`.
-"""
-
-JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
-"""Type representing any JSON data."""
-
-ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
-"""Return value of a tool function."""
-ToolContextFunc = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue]
-"""A tool function that takes `RunContext` as the first argument.
-
-Usage `ToolContextFunc[AgentDeps, ToolParams]`.
-"""
-ToolPlainFunc = Callable[ToolParams, ToolReturnValue]
-"""A tool function that does not take `RunContext` as the first argument.
-
-Usage `ToolPlainFunc[ToolParams]`.
-"""
diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 19e615862f..9ab0116502 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -259,20 +259,30 @@ def infer_model(model: Model | KnownModelName) -> Model:
class AbstractToolDefinition(Protocol):
"""Abstract definition of a function/tool.
- This is used for both tools and result tools.
+ This is used for both function tools and result tools.
"""
- name: str
- """The name of the tool."""
- description: str
- """The description of the tool."""
- json_schema: ObjectJsonSchema
- """The JSON schema for the tool's arguments."""
- outer_typed_dict_key: str | None
- """The key in the outer [TypedDict] that wraps a result tool.
+ @property
+ def name(self) -> str:
+ """The name of the tool."""
+ ...
- This will only be set for result tools which don't have an `object` JSON schema.
- """
+ @property
+ def description(self) -> str:
+ """The description of the tool."""
+ ...
+
+ @property
+ def json_schema(self) -> ObjectJsonSchema:
+ """The JSON schema for the tool's arguments."""
+ ...
+
+ @property
+ def outer_typed_dict_key(self) -> str | None:
+ """The key in the outer [TypedDict] that wraps a result tool.
+
+ This will only be set for result tools which don't have an `object` JSON schema.
+ """
@cache
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index c4165f2c85..06c5c4016f 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -9,7 +9,7 @@
import logfire_api
from . import _result, _utils, exceptions, messages, models
-from .dependencies import AgentDeps
+from .tools import AgentDeps
__all__ = (
'ResultData',
diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py
new file mode 100644
index 0000000000..61995c3ad6
--- /dev/null
+++ b/pydantic_ai_slim/pydantic_ai/tools.py
@@ -0,0 +1,247 @@
+from __future__ import annotations as _annotations
+
+import inspect
+from collections.abc import Awaitable, Mapping, Sequence
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
+
+from pydantic import ValidationError
+from pydantic_core import SchemaValidator
+from typing_extensions import Concatenate, ParamSpec, TypeAlias, final
+
+from . import _pydantic, _utils, messages
+from .exceptions import ModelRetry, UnexpectedModelBehavior
+
+if TYPE_CHECKING:
+ from .result import ResultData
+else:
+ ResultData = Any
+
+
+__all__ = (
+ 'AgentDeps',
+ 'RunContext',
+ 'ResultValidatorFunc',
+ 'SystemPromptFunc',
+ 'ToolReturnValue',
+ 'ToolFuncContext',
+ 'ToolFuncPlain',
+ 'ToolFuncEither',
+ 'ToolParams',
+ 'JsonData',
+ 'Tool',
+)
+
+AgentDeps = TypeVar('AgentDeps')
+"""Type variable for agent dependencies."""
+
+
+@dataclass
+class RunContext(Generic[AgentDeps]):
+ """Information about the current call."""
+
+ deps: AgentDeps
+ """Dependencies for the agent."""
+ retry: int
+ """Number of retries so far."""
+ tool_name: str | None
+ """Name of the tool being called."""
+
+
+ToolParams = ParamSpec('ToolParams')
+"""Retrieval function param spec."""
+
+SystemPromptFunc = Union[
+ Callable[[RunContext[AgentDeps]], str],
+ Callable[[RunContext[AgentDeps]], Awaitable[str]],
+ Callable[[], str],
+ Callable[[], Awaitable[str]],
+]
+"""A function that may or maybe not take `RunContext` as an argument, and may or may not be async.
+
+Usage `SystemPromptFunc[AgentDeps]`.
+"""
+
+ResultValidatorFunc = Union[
+ Callable[[RunContext[AgentDeps], ResultData], ResultData],
+ Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]],
+ Callable[[ResultData], ResultData],
+ Callable[[ResultData], Awaitable[ResultData]],
+]
+"""
+A function that always takes `ResultData` and returns `ResultData`,
+but may or may not take `RunContext` as a first argument, and may or may not be async.
+
+Usage `ResultValidator[AgentDeps, ResultData]`.
+"""
+
+JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
+"""Type representing any JSON data."""
+
+ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
+"""Return value of a tool function."""
+ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue]
+"""A tool function that takes `RunContext` as the first argument.
+
+Usage `ToolFuncContext[AgentDeps, ToolParams]`.
+"""
+ToolFuncPlain = Callable[ToolParams, ToolReturnValue]
+"""A tool function that does not take `RunContext` as the first argument.
+
+Usage `ToolFuncPlain[ToolParams]`.
+"""
+ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
+"""Either kind of tool function.
+
+This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
+[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
+
+Usage `ToolFuncEither[AgentDeps, ToolParams]`.
+"""
+
+A = TypeVar('A')
+
+
+@final
+@dataclass(init=False)
+class Tool(Generic[AgentDeps]):
+ """A tool function for an agent."""
+
+ function: ToolFuncEither[AgentDeps, ...]
+ takes_ctx: bool
+ max_retries: int | None
+ name: str
+ description: str
+ _is_async: bool = field(init=False)
+ _single_arg_name: str | None = field(init=False)
+ _positional_fields: list[str] = field(init=False)
+ _var_positional_field: str | None = field(init=False)
+ _validator: SchemaValidator = field(init=False, repr=False)
+ _json_schema: _utils.ObjectJsonSchema = field(init=False)
+ _current_retry: int = field(default=0, init=False)
+
+ def __init__(
+ self,
+ function: ToolFuncEither[AgentDeps, ...],
+ takes_ctx: bool,
+ *,
+ max_retries: int | None = None,
+ name: str | None = None,
+ description: str | None = None,
+ ):
+ """Create a new tool instance.
+
+ Example usage:
+
+ ```py
+ from pydantic_ai import Agent, RunContext, Tool
+
+ async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
+ return f'{ctx.deps} {x} {y}'
+
+ agent = Agent('test', tools=[Tool(my_tool, True)])
+ ```
+
+ Args:
+ function: The Python function to call as the tool.
+ takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument.
+ max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
+ name: Name of the tool, inferred from the function if left blank.
+ description: Description of the tool, inferred from the function if left blank.
+ """
+ f = _pydantic.function_schema(function, takes_ctx)
+ self.function = function
+ self.takes_ctx = takes_ctx
+ self.max_retries = max_retries
+ self.name = name or function.__name__
+ self.description = description or f['description']
+ self._is_async = inspect.iscoroutinefunction(self.function)
+ self._single_arg_name = f['single_arg_name']
+ self._positional_fields = f['positional_fields']
+ self._var_positional_field = f['var_positional_field']
+ self._validator = f['validator']
+ self._json_schema = f['json_schema']
+
+ @staticmethod
+ def infer(function: ToolFuncEither[A, ...] | Tool[A]) -> Tool[A]:
+ """Create a tool from a pure function, inferring whether it takes `RunContext` as its first argument.
+
+ Args:
+ function: The tool function to wrap; or for convenience, a `Tool` instance.
+
+ Returns:
+ A new `Tool` instance.
+ """
+ if isinstance(function, Tool):
+ return function
+ else:
+ return Tool(function, takes_ctx=_pydantic.takes_ctx(function))
+
+ def reset(self) -> None:
+ """Reset the current retry count."""
+ self._current_retry = 0
+
+ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message:
+ """Run the tool function asynchronously."""
+ try:
+ if isinstance(message.args, messages.ArgsJson):
+ args_dict = self._validator.validate_json(message.args.args_json)
+ else:
+ args_dict = self._validator.validate_python(message.args.args_dict)
+ except ValidationError as e:
+ return self._on_error(e, message)
+
+ args, kwargs = self._call_args(deps, args_dict, message)
+ try:
+ if self._is_async:
+ function = cast(Callable[[Any], Awaitable[str]], self.function)
+ response_content = await function(*args, **kwargs)
+ else:
+ function = cast(Callable[[Any], str], self.function)
+ response_content = await _utils.run_in_executor(function, *args, **kwargs)
+ except ModelRetry as e:
+ return self._on_error(e, message)
+
+ self._current_retry = 0
+ return messages.ToolReturn(
+ tool_name=message.tool_name,
+ content=response_content,
+ tool_id=message.tool_id,
+ )
+
+ @property
+ def json_schema(self) -> _utils.ObjectJsonSchema:
+ return self._json_schema
+
+ @property
+ def outer_typed_dict_key(self) -> str | None:
+ return None
+
+ def _call_args(
+ self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall
+ ) -> tuple[list[Any], dict[str, Any]]:
+ if self._single_arg_name:
+ args_dict = {self._single_arg_name: args_dict}
+
+ args = [RunContext(deps, self._current_retry, message.tool_name)] if self.takes_ctx else []
+ for positional_field in self._positional_fields:
+ args.append(args_dict.pop(positional_field))
+ if self._var_positional_field:
+ args.extend(args_dict.pop(self._var_positional_field))
+
+ return args, args_dict
+
+ def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt:
+ self._current_retry += 1
+ if self.max_retries is None or self._current_retry > self.max_retries:
+ raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc
+ else:
+ if isinstance(exc, ValidationError):
+ content = exc.errors(include_url=False)
+ else:
+ content = exc.message
+ return messages.RetryPrompt(
+ tool_name=call_message.tool_name,
+ content=content,
+ tool_id=call_message.tool_id,
+ )
diff --git a/tests/test_examples.py b/tests/test_examples.py
index b7856f1c63..3f2c880dc5 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -98,6 +98,9 @@ def test_docs_examples(
# waiting for https://github.com/pydantic/pytest-examples/issues/43
if 'import DatabaseConn' in example.source:
ruff_ignore.append('I001')
+ elif 'async def my_tool(' in example.source:
+ # not sure what's going on here, just ignore it for now
+ ruff_ignore.append('I001')
line_length = 88
if opt_title in ('streamed_hello_world.py', 'streamed_user_profile.py'):
@@ -231,7 +234,7 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyRespo
elif m.role == 'tool-return' and m.tool_name == 'roll_die':
return ModelStructuredResponse(calls=[ToolCall(tool_name='get_player_name', args=ArgsDict({}))])
elif m.role == 'tool-return' and m.tool_name == 'get_player_name':
- return ModelTextResponse(content="Congratulations Adam, you guessed correctly! You're a winner!")
+ return ModelTextResponse(content="Congratulations Anne, you guessed correctly! You're a winner!")
if m.role == 'retry-prompt' and isinstance(m.content, str) and m.content.startswith("No user found with name 'Joh"):
return ModelStructuredResponse(
calls=[ToolCall(tool_name='get_user_by_name', args=ArgsDict({'name': 'John Doe'}))]
diff --git a/tests/test_retrievers.py b/tests/test_tools.py
similarity index 71%
rename from tests/test_retrievers.py
rename to tests/test_tools.py
index 32191d27ec..640b8b1ef9 100644
--- a/tests/test_retrievers.py
+++ b/tests/test_tools.py
@@ -5,7 +5,7 @@
from inline_snapshot import snapshot
from pydantic import BaseModel, Field
-from pydantic_ai import Agent, RunContext, UserError
+from pydantic_ai import Agent, RunContext, Tool, UserError
from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.test import TestModel
@@ -22,7 +22,7 @@ def invalid_tool(x: int) -> str: # pragma: no cover
assert str(exc_info.value) == snapshot(
'Error generating schema for test_tool_no_ctx..invalid_tool:\n'
- ' First argument must be a RunContext instance when using `.tool`'
+ ' First parameter of tools that take context must be annotated with RunContext[...]'
)
@@ -37,7 +37,7 @@ async def invalid_tool(ctx: RunContext[None]) -> str: # pragma: no cover
assert str(exc_info.value) == snapshot(
'Error generating schema for test_tool_plain_with_ctx..invalid_tool:\n'
- ' RunContext instance can only be used with `.tool`'
+ ' RunContext annotations can only be used with tools that take context'
)
@@ -52,8 +52,8 @@ def invalid_tool(x: int, ctx: RunContext[None]) -> str: # pragma: no cover
assert str(exc_info.value) == snapshot(
'Error generating schema for test_tool_ctx_second..invalid_tool:\n'
- ' First argument must be a RunContext instance when using `.tool`\n'
- ' RunContext instance can only be used as the first argument'
+ ' First parameter of tools that take context must be annotated with RunContext[...]\n'
+ ' RunContext annotations can only be used as the first argument'
)
@@ -273,3 +273,73 @@ def takes_just_model(model: Foo, z: int) -> str:
result = agent.run_sync('', model=TestModel())
assert result.data == snapshot('{"takes_just_model":"0 a 0"}')
+
+
+# pyright: reportPrivateUsage=false
+def test_init_tool_plain(set_event_loop: None):
+ call_args: list[int] = []
+
+ def plain_tool(x: int) -> int:
+ call_args.append(x)
+ return x + 1
+
+ agent = Agent('test', tools=[Tool(plain_tool, False)], retries=7)
+ result = agent.run_sync('foobar')
+ assert result.data == snapshot('{"plain_tool":1}')
+ assert call_args == snapshot([0])
+ assert agent._function_tools['plain_tool'].takes_ctx is False
+ assert agent._function_tools['plain_tool'].max_retries == 7
+
+ agent_infer = Agent('test', tools=[plain_tool], retries=7)
+ result = agent_infer.run_sync('foobar')
+ assert result.data == snapshot('{"plain_tool":1}')
+ assert call_args == snapshot([0, 0])
+ assert agent_infer._function_tools['plain_tool'].takes_ctx is False
+ assert agent_infer._function_tools['plain_tool'].max_retries == 7
+
+
+def ctx_tool(ctx: RunContext[int], x: int) -> int:
+ return x + ctx.deps
+
+
+# pyright: reportPrivateUsage=false
+def test_init_tool_ctx(set_event_loop: None):
+ agent = Agent('test', tools=[Tool(ctx_tool, True, max_retries=3)], deps_type=int, retries=7)
+ result = agent.run_sync('foobar', deps=5)
+ assert result.data == snapshot('{"ctx_tool":5}')
+ assert agent._function_tools['ctx_tool'].takes_ctx is True
+ assert agent._function_tools['ctx_tool'].max_retries == 3
+
+ agent_infer = Agent('test', tools=[ctx_tool], deps_type=int)
+ result = agent_infer.run_sync('foobar', deps=6)
+ assert result.data == snapshot('{"ctx_tool":6}')
+ assert agent_infer._function_tools['ctx_tool'].takes_ctx is True
+
+
+def test_repeat_tool():
+ with pytest.raises(UserError, match="Tool name conflicts with existing tool: 'ctx_tool'"):
+ Agent('test', tools=[Tool(ctx_tool, True), ctx_tool], deps_type=int)
+
+
+def test_tool_return_conflict():
+ # this is okay
+ Agent('test', tools=[ctx_tool], deps_type=int)
+ # this is also okay
+ Agent('test', tools=[ctx_tool], deps_type=int, result_type=int)
+ # this raises an error
+ with pytest.raises(UserError, match="Tool name conflicts with result schema name: 'ctx_tool'"):
+ Agent('test', tools=[ctx_tool], deps_type=int, result_type=int, result_tool_name='ctx_tool')
+
+
+def test_init_ctx_tool_invalid():
+ def plain_tool(x: int) -> int:
+ return x + 1
+
+ m = r'First parameter of tools that take context must be annotated with RunContext\[\.\.\.\]'
+ with pytest.raises(UserError, match=m):
+ Tool(plain_tool, True)
+
+
+def test_init_plain_tool_invalid():
+ with pytest.raises(UserError, match='RunContext annotations can only be used with tools that take context'):
+ Tool(ctx_tool, False)
diff --git a/tests/typed_agent.py b/tests/typed_agent.py
index a73f389b82..19bed49b16 100644
--- a/tests/typed_agent.py
+++ b/tests/typed_agent.py
@@ -5,7 +5,7 @@
from dataclasses import dataclass
from typing import Callable, Union, assert_type
-from pydantic_ai import Agent, ModelRetry, RunContext
+from pydantic_ai import Agent, ModelRetry, RunContext, Tool
from pydantic_ai.result import RunResult
@@ -151,3 +151,30 @@ def run_sync3() -> None:
result = union_agent.run_sync('testing')
assert_type(result, RunResult[Union[Foo, Bar]])
assert_type(result.data, Union[Foo, Bar])
+
+
+def foobar_ctx(ctx: RunContext[int], x: str, y: int) -> str:
+ return f'{x} {y}'
+
+
+def foobar_plain(x: str, y: int) -> str:
+ return f'{x} {y}'
+
+
+Tool(foobar_ctx, True)
+Tool(foobar_plain, False)
+
+# unfortunately we can't type check these cases, since from a typing perspective `foobar_ctx` is valid as a plain tool
+Tool(foobar_ctx, False)
+Tool(foobar_plain, True)
+
+Agent('test', tools=[foobar_ctx], deps_type=int)
+Agent('test', tools=[foobar_plain], deps_type=int)
+Agent('test', tools=[foobar_plain])
+Agent('test', tools=[Tool(foobar_ctx, True)], deps_type=int)
+Agent('test', tools=[Tool(foobar_plain, False)], deps_type=int)
+Agent('test', tools=[Tool(foobar_ctx, True), foobar_ctx, foobar_plain], deps_type=int)
+
+Agent('test', tools=[foobar_ctx], deps_type=str) # pyright: ignore[reportArgumentType]
+Agent('test', tools=[foobar_ctx]) # pyright: ignore[reportArgumentType]
+Agent('test', tools=[Tool(foobar_ctx, True)]) # pyright: ignore[reportArgumentType]