From 9503fd41241ab53dc7357cde3263c6091d59af5b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 3 Dec 2024 12:08:07 +0000 Subject: [PATCH 1/8] allow adding tools via Agent __init__ --- docs/api/tool.md | 3 + mkdocs.yml | 1 + pydantic_ai_slim/pydantic_ai/__init__.py | 3 +- pydantic_ai_slim/pydantic_ai/_pydantic.py | 12 ++-- pydantic_ai_slim/pydantic_ai/agent.py | 52 +++++++++------- pydantic_ai_slim/pydantic_ai/dependencies.py | 17 +++-- .../pydantic_ai/{_tool.py => tool.py} | 62 +++++++++---------- tests/{test_retrievers.py => test_tools.py} | 24 ++++++- 8 files changed, 108 insertions(+), 66 deletions(-) create mode 100644 docs/api/tool.md rename pydantic_ai_slim/pydantic_ai/{_tool.py => tool.py} (69%) rename tests/{test_retrievers.py => test_tools.py} (91%) diff --git a/docs/api/tool.md b/docs/api/tool.md new file mode 100644 index 0000000000..9e3becc1bb --- /dev/null +++ b/docs/api/tool.md @@ -0,0 +1,3 @@ +# `pydantic_ai.tool` + +::: pydantic_ai.tool diff --git a/mkdocs.yml b/mkdocs.yml index 589b1d2604..ea91b6e0ab 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -31,6 +31,7 @@ nav: - examples/chat-app.md - API Reference: - api/agent.md + - api/tool.md - api/result.md - api/messages.md - api/dependencies.md diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index c4591d2a17..1a24297e4c 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -3,6 +3,7 @@ from .agent import Agent from .dependencies import RunContext from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError +from .tool import Tool -__all__ = 'Agent', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__' +__all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__' __version__ = version('pydantic_ai_slim') diff --git a/pydantic_ai_slim/pydantic_ai/_pydantic.py b/pydantic_ai_slim/pydantic_ai/_pydantic.py index a948acf051..01641d6db8 100644 --- a/pydantic_ai_slim/pydantic_ai/_pydantic.py +++ b/pydantic_ai_slim/pydantic_ai/_pydantic.py @@ -6,7 +6,7 @@ from __future__ import annotations as _annotations from inspect import Parameter, signature -from typing import TYPE_CHECKING, Any, TypedDict, cast, get_origin +from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin from pydantic import ConfigDict, TypeAdapter from pydantic._internal import _decorators, _generate_schema, _typing_extra @@ -20,8 +20,7 @@ from ._utils import ObjectJsonSchema, check_object_json_schema, is_model_like if TYPE_CHECKING: - from . import _tool - from .dependencies import AgentDeps, ToolParams + pass __all__ = 'function_schema', 'LazyTypeAdapter' @@ -39,17 +38,16 @@ class FunctionSchema(TypedDict): var_positional_field: str | None -def function_schema(either_function: _tool.ToolEitherFunc[AgentDeps, ToolParams]) -> FunctionSchema: # noqa: C901 +def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSchema: # noqa: C901 """Build a Pydantic validator and JSON schema from a tool function. Args: - either_function: The function to build a validator and JSON schema for. + function: The function to build a validator and JSON schema for. + takes_ctx: Whether the function takes a `RunContext` first argument. Returns: A `FunctionSchema` instance. """ - function = either_function.whichever() - takes_ctx = either_function.is_left() config = ConfigDict(title=function.__name__) config_wrapper = ConfigWrapper(config) gen_schema = _generate_schema.GenerateSchema(config_wrapper) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 13cb4c4959..0b4a868e10 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -12,15 +12,15 @@ from . import ( _result, _system_prompt, - _tool as _r, _utils, exceptions, messages as _messages, models, result, ) -from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc +from .dependencies import AgentDeps, RunContext, ToolFuncContext, ToolFuncEither, ToolFuncPlain, ToolParams from .result import ResultData +from .tool import Tool __all__ = ('Agent',) @@ -58,7 +58,7 @@ class Agent(Generic[AgentDeps, ResultData]): _result_validators: list[_result.ResultValidator[AgentDeps, ResultData]] = field(repr=False) _allow_text_result: bool = field(repr=False) _system_prompts: tuple[str, ...] = field(repr=False) - _function_tools: dict[str, _r.Tool[AgentDeps, Any]] = field(repr=False) + _function_tools: dict[str, Tool[AgentDeps]] = field(repr=False) _default_retries: int = field(repr=False) _system_prompt_functions: list[_system_prompt.SystemPromptRunner[AgentDeps]] = field(repr=False) _deps_type: type[AgentDeps] = field(repr=False) @@ -83,6 +83,7 @@ def __init__( result_tool_name: str = 'final_result', result_tool_description: str | None = None, result_retries: int | None = None, + tools: Sequence[Tool[AgentDeps]] = (), defer_model_check: bool = False, ): """Create an agent. @@ -101,6 +102,8 @@ def __init__( result_tool_name: The name of the tool to use for the final result. result_tool_description: The description of the final result tool. result_retries: The maximum number of retries to allow for result validation, defaults to `retries`. + tools: Tools to register with the agent, you can also register tools via the decorators + [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, which checks for the necessary environment variables. Set this to `false` @@ -119,7 +122,9 @@ def __init__( self._allow_text_result = self._result_schema is None or self._result_schema.allow_text_result self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) - self._function_tools: dict[str, _r.Tool[AgentDeps, Any]] = {} + self._function_tools = {} + for tool in tools: + self._register_tool(tool) self._deps_type = deps_type self._default_retries = retries self._system_prompt_functions = [] @@ -439,16 +444,16 @@ async def result_validator_deps(ctx: RunContext[str], data: str) -> str: return func @overload - def tool(self, func: ToolContextFunc[AgentDeps, ToolParams], /) -> ToolContextFunc[AgentDeps, ToolParams]: ... + def tool(self, func: ToolFuncContext[AgentDeps, ToolParams], /) -> ToolFuncContext[AgentDeps, ToolParams]: ... @overload def tool( self, /, *, retries: int | None = None - ) -> Callable[[ToolContextFunc[AgentDeps, ToolParams]], ToolContextFunc[AgentDeps, ToolParams]]: ... + ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ... def tool( self, - func: ToolContextFunc[AgentDeps, ToolParams] | None = None, + func: ToolFuncContext[AgentDeps, ToolParams] | None = None, /, *, retries: int | None = None, @@ -491,27 +496,27 @@ async def spam(ctx: RunContext[str], y: float) -> float: if func is None: def tool_decorator( - func_: ToolContextFunc[AgentDeps, ToolParams], - ) -> ToolContextFunc[AgentDeps, ToolParams]: + func_: ToolFuncContext[AgentDeps, ToolParams], + ) -> ToolFuncContext[AgentDeps, ToolParams]: # noinspection PyTypeChecker - self._register_tool(_utils.Either(left=func_), retries) + self._register_function(func_, True, retries) return func_ return tool_decorator else: # noinspection PyTypeChecker - self._register_tool(_utils.Either(left=func), retries) + self._register_function(func, True, retries) return func @overload - def tool_plain(self, func: ToolPlainFunc[ToolParams], /) -> ToolPlainFunc[ToolParams]: ... + def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ... @overload def tool_plain( self, /, *, retries: int | None = None - ) -> Callable[[ToolPlainFunc[ToolParams]], ToolPlainFunc[ToolParams]]: ... + ) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ... - def tool_plain(self, func: ToolPlainFunc[ToolParams] | None = None, /, *, retries: int | None = None) -> Any: + def tool_plain(self, func: ToolFuncPlain[ToolParams] | None = None, /, *, retries: int | None = None) -> Any: """Decorator to register a tool function which DOES NOT take `RunContext` as an argument. Can decorate a sync or async functions. @@ -548,23 +553,26 @@ async def spam(ctx: RunContext[str]) -> float: """ if func is None: - def tool_decorator( - func_: ToolPlainFunc[ToolParams], - ) -> ToolPlainFunc[ToolParams]: + def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: # noinspection PyTypeChecker - self._register_tool(_utils.Either(right=func_), retries) + self._register_function(func_, False, retries) return func_ return tool_decorator else: - self._register_tool(_utils.Either(right=func), retries) + self._register_function(func, False, retries) return func - def _register_tool(self, func: _r.ToolEitherFunc[AgentDeps, ToolParams], retries: int | None) -> None: - """Private utility to register a tool function.""" + def _register_function( + self, func: ToolFuncEither[AgentDeps, ToolParams], takes_ctx: bool, retries: int | None + ) -> None: + """Private utility to register a function as a tool.""" retries_ = retries if retries is not None else self._default_retries - tool = _r.Tool[AgentDeps, ToolParams](func, retries_) + tool = Tool(func, takes_ctx, retries_) + self._register_tool(tool) + def _register_tool(self, tool: Tool[AgentDeps]) -> None: + """Private utility to register a tool instance.""" if self._result_schema and tool.name in self._result_schema.tools: raise ValueError(f'Tool name conflicts with result schema name: {tool.name!r}') diff --git a/pydantic_ai_slim/pydantic_ai/dependencies.py b/pydantic_ai_slim/pydantic_ai/dependencies.py index 1f0d300689..e42d64c4cb 100644 --- a/pydantic_ai_slim/pydantic_ai/dependencies.py +++ b/pydantic_ai_slim/pydantic_ai/dependencies.py @@ -17,8 +17,9 @@ 'ResultValidatorFunc', 'SystemPromptFunc', 'ToolReturnValue', - 'ToolContextFunc', - 'ToolPlainFunc', + 'ToolFuncContext', + 'ToolFuncPlain', + 'ToolFuncEither', 'ToolParams', 'JsonData', ) @@ -71,13 +72,21 @@ class RunContext(Generic[AgentDeps]): ToolReturnValue = Union[JsonData, Awaitable[JsonData]] """Return value of a tool function.""" -ToolContextFunc = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue] +ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue] """A tool function that takes `RunContext` as the first argument. Usage `ToolContextFunc[AgentDeps, ToolParams]`. """ -ToolPlainFunc = Callable[ToolParams, ToolReturnValue] +ToolFuncPlain = Callable[ToolParams, ToolReturnValue] """A tool function that does not take `RunContext` as the first argument. Usage `ToolPlainFunc[ToolParams]`. """ +ToolFuncEither = ToolFuncContext[AgentDeps, ToolParams] | ToolFuncPlain[ToolParams] +"""Either kind of tool function. + +This is just a union of [`ToolFuncContext`][pydantic_ai.dependencies.ToolFuncContext] and +[`ToolFuncPlain`][pydantic_ai.dependencies.ToolFuncPlain]. + +Usage `ToolFuncEither[AgentDeps, ToolParams]`. +""" diff --git a/pydantic_ai_slim/pydantic_ai/_tool.py b/pydantic_ai_slim/pydantic_ai/tool.py similarity index 69% rename from pydantic_ai_slim/pydantic_ai/_tool.py rename to pydantic_ai_slim/pydantic_ai/tool.py index 9d82360bde..5146c50adb 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool.py +++ b/pydantic_ai_slim/pydantic_ai/tool.py @@ -9,45 +9,45 @@ from pydantic_core import SchemaValidator from . import _pydantic, _utils, messages -from .dependencies import AgentDeps, RunContext, ToolContextFunc, ToolParams, ToolPlainFunc +from .dependencies import AgentDeps, RunContext, ToolFuncEither from .exceptions import ModelRetry, UnexpectedModelBehavior -# Usage `ToolEitherFunc[AgentDependencies, P]` -ToolEitherFunc = _utils.Either[ToolContextFunc[AgentDeps, ToolParams], ToolPlainFunc[ToolParams]] +__all__ = ('Tool',) -@dataclass(init=False) -class Tool(Generic[AgentDeps, ToolParams]): +@dataclass +class Tool(Generic[AgentDeps]): """A tool function for an agent.""" - name: str - description: str - function: ToolEitherFunc[AgentDeps, ToolParams] = field(repr=False) - is_async: bool - single_arg_name: str | None - positional_fields: list[str] - var_positional_field: str | None - validator: SchemaValidator = field(repr=False) - json_schema: _utils.ObjectJsonSchema - max_retries: int - _current_retry: int = 0 - outer_typed_dict_key: str | None = None - - def __init__(self, function: ToolEitherFunc[AgentDeps, ToolParams], retries: int): - """Build a Tool dataclass from a function.""" - self.function = function - # noinspection PyTypeChecker - f = _pydantic.function_schema(function) - raw_function = function.whichever() - self.name = raw_function.__name__ - self.description = f['description'] - self.is_async = inspect.iscoroutinefunction(raw_function) + function: ToolFuncEither[AgentDeps, ...] + """The Python function to call as the tool.""" + takes_ctx: bool + """Whether the function takes a [`RunContext`][pydantic_ai.dependencies.RunContext] first argument.""" + max_retries: int = 1 + """Maximum number of retries allowed for this tool.""" + name: str = '' + """Name of the tool, inferred from the function if left blank.""" + description: str = '' + """Description of the tool, inferred from the function if left blank.""" + is_async: bool = field(init=False) + single_arg_name: str | None = field(init=False) + positional_fields: list[str] = field(init=False) + var_positional_field: str | None = field(init=False) + validator: SchemaValidator = field(init=False, repr=False) + json_schema: _utils.ObjectJsonSchema = field(init=False) + _current_retry: int = field(default=0, init=False) + outer_typed_dict_key: str | None = field(default=None, init=False) + + def __post_init__(self): + f = _pydantic.function_schema(self.function, self.takes_ctx) + self.name = self.name or self.function.__name__ + self.description = self.description or f['description'] + self.is_async = inspect.iscoroutinefunction(self.function) self.single_arg_name = f['single_arg_name'] self.positional_fields = f['positional_fields'] self.var_positional_field = f['var_positional_field'] self.validator = f['validator'] self.json_schema = f['json_schema'] - self.max_retries = retries def reset(self) -> None: """Reset the current retry count.""" @@ -66,10 +66,10 @@ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Mes args, kwargs = self._call_args(deps, args_dict, message) try: if self.is_async: - function = cast(Callable[[Any], Awaitable[str]], self.function.whichever()) + function = cast(Callable[[Any], Awaitable[str]], self.function) response_content = await function(*args, **kwargs) else: - function = cast(Callable[[Any], str], self.function.whichever()) + function = cast(Callable[[Any], str], self.function) response_content = await _utils.run_in_executor(function, *args, **kwargs) except ModelRetry as e: return self._on_error(e, message) @@ -87,7 +87,7 @@ def _call_args( if self.single_arg_name: args_dict = {self.single_arg_name: args_dict} - args = [RunContext(deps, self._current_retry, message.tool_name)] if self.function.is_left() else [] + args = [RunContext(deps, self._current_retry, message.tool_name)] if self.takes_ctx else [] for positional_field in self.positional_fields: args.append(args_dict.pop(positional_field)) if self.var_positional_field: diff --git a/tests/test_retrievers.py b/tests/test_tools.py similarity index 91% rename from tests/test_retrievers.py rename to tests/test_tools.py index 32191d27ec..761294dbac 100644 --- a/tests/test_retrievers.py +++ b/tests/test_tools.py @@ -5,7 +5,7 @@ from inline_snapshot import snapshot from pydantic import BaseModel, Field -from pydantic_ai import Agent, RunContext, UserError +from pydantic_ai import Agent, RunContext, Tool, UserError from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel @@ -273,3 +273,25 @@ def takes_just_model(model: Foo, z: int) -> str: result = agent.run_sync('', model=TestModel()) assert result.data == snapshot('{"takes_just_model":"0 a 0"}') + + +def test_init_tool_plain(set_event_loop: None): + call_args: list[int] = [] + + def plain_tool(x: int) -> int: + call_args.append(x) + return x + 1 + + agent = Agent('test', tools=[Tool(plain_tool, False)]) + result = agent.run_sync('foobar') + assert result.data == snapshot('{"plain_tool":1}') + assert call_args == snapshot([0]) + + +def test_init_tool_ctx(set_event_loop: None): + def ctx_tool(ctx: RunContext[int], x: int) -> int: + return x + ctx.deps + + agent = Agent('test', tools=[Tool(ctx_tool, True)], deps_type=int) + result = agent.run_sync('foobar', deps=5) + assert result.data == snapshot('{"ctx_tool":5}') From 3d451542744d47f5fa1780bc19b52d2cc212d40b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 3 Dec 2024 18:20:30 +0000 Subject: [PATCH 2/8] make more dc fields priviate --- .../pydantic_ai/models/__init__.py | 32 +++++++----- pydantic_ai_slim/pydantic_ai/tool.py | 49 +++++++++++-------- 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 19e615862f..9ab0116502 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -259,20 +259,30 @@ def infer_model(model: Model | KnownModelName) -> Model: class AbstractToolDefinition(Protocol): """Abstract definition of a function/tool. - This is used for both tools and result tools. + This is used for both function tools and result tools. """ - name: str - """The name of the tool.""" - description: str - """The description of the tool.""" - json_schema: ObjectJsonSchema - """The JSON schema for the tool's arguments.""" - outer_typed_dict_key: str | None - """The key in the outer [TypedDict] that wraps a result tool. + @property + def name(self) -> str: + """The name of the tool.""" + ... - This will only be set for result tools which don't have an `object` JSON schema. - """ + @property + def description(self) -> str: + """The description of the tool.""" + ... + + @property + def json_schema(self) -> ObjectJsonSchema: + """The JSON schema for the tool's arguments.""" + ... + + @property + def outer_typed_dict_key(self) -> str | None: + """The key in the outer [TypedDict] that wraps a result tool. + + This will only be set for result tools which don't have an `object` JSON schema. + """ @cache diff --git a/pydantic_ai_slim/pydantic_ai/tool.py b/pydantic_ai_slim/pydantic_ai/tool.py index 5146c50adb..0d72de8619 100644 --- a/pydantic_ai_slim/pydantic_ai/tool.py +++ b/pydantic_ai_slim/pydantic_ai/tool.py @@ -29,25 +29,24 @@ class Tool(Generic[AgentDeps]): """Name of the tool, inferred from the function if left blank.""" description: str = '' """Description of the tool, inferred from the function if left blank.""" - is_async: bool = field(init=False) - single_arg_name: str | None = field(init=False) - positional_fields: list[str] = field(init=False) - var_positional_field: str | None = field(init=False) - validator: SchemaValidator = field(init=False, repr=False) - json_schema: _utils.ObjectJsonSchema = field(init=False) + _is_async: bool = field(init=False) + _single_arg_name: str | None = field(init=False) + _positional_fields: list[str] = field(init=False) + _var_positional_field: str | None = field(init=False) + _validator: SchemaValidator = field(init=False, repr=False) + _json_schema: _utils.ObjectJsonSchema = field(init=False) _current_retry: int = field(default=0, init=False) - outer_typed_dict_key: str | None = field(default=None, init=False) def __post_init__(self): f = _pydantic.function_schema(self.function, self.takes_ctx) self.name = self.name or self.function.__name__ self.description = self.description or f['description'] - self.is_async = inspect.iscoroutinefunction(self.function) - self.single_arg_name = f['single_arg_name'] - self.positional_fields = f['positional_fields'] - self.var_positional_field = f['var_positional_field'] - self.validator = f['validator'] - self.json_schema = f['json_schema'] + self._is_async = inspect.iscoroutinefunction(self.function) + self._single_arg_name = f['single_arg_name'] + self._positional_fields = f['positional_fields'] + self._var_positional_field = f['var_positional_field'] + self._validator = f['validator'] + self._json_schema = f['json_schema'] def reset(self) -> None: """Reset the current retry count.""" @@ -57,15 +56,15 @@ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Mes """Run the tool function asynchronously.""" try: if isinstance(message.args, messages.ArgsJson): - args_dict = self.validator.validate_json(message.args.args_json) + args_dict = self._validator.validate_json(message.args.args_json) else: - args_dict = self.validator.validate_python(message.args.args_dict) + args_dict = self._validator.validate_python(message.args.args_dict) except ValidationError as e: return self._on_error(e, message) args, kwargs = self._call_args(deps, args_dict, message) try: - if self.is_async: + if self._is_async: function = cast(Callable[[Any], Awaitable[str]], self.function) response_content = await function(*args, **kwargs) else: @@ -81,17 +80,25 @@ async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Mes tool_id=message.tool_id, ) + @property + def json_schema(self) -> _utils.ObjectJsonSchema: + return self._json_schema + + @property + def outer_typed_dict_key(self) -> str | None: + return None + def _call_args( self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall ) -> tuple[list[Any], dict[str, Any]]: - if self.single_arg_name: - args_dict = {self.single_arg_name: args_dict} + if self._single_arg_name: + args_dict = {self._single_arg_name: args_dict} args = [RunContext(deps, self._current_retry, message.tool_name)] if self.takes_ctx else [] - for positional_field in self.positional_fields: + for positional_field in self._positional_fields: args.append(args_dict.pop(positional_field)) - if self.var_positional_field: - args.extend(args_dict.pop(self.var_positional_field)) + if self._var_positional_field: + args.extend(args_dict.pop(self._var_positional_field)) return args, args_dict From 0ff9023610eabc51296dfc12f0df73e39be3c727 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 4 Dec 2024 08:54:52 +0000 Subject: [PATCH 3/8] fix for 3.9 --- docs/install.md | 2 +- pydantic_ai_slim/pydantic_ai/dependencies.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/install.md b/docs/install.md index f22f3dd9fa..9d1b5ad268 100644 --- a/docs/install.md +++ b/docs/install.md @@ -40,7 +40,7 @@ To run the examples, follow instructions in the [examples docs](examples/index.m ## Slim Install -If you know which model you're going to use and want to avoid installing superfluous package, you can use the [`pydantic-ai-slim`](https://pypi.org/project/pydantic-ai-slim/) package. +If you know which model you're going to use and want to avoid installing superfluous packages, you can use the [`pydantic-ai-slim`](https://pypi.org/project/pydantic-ai-slim/) package. If you're using just [`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel], run: diff --git a/pydantic_ai_slim/pydantic_ai/dependencies.py b/pydantic_ai_slim/pydantic_ai/dependencies.py index e42d64c4cb..807f273ed0 100644 --- a/pydantic_ai_slim/pydantic_ai/dependencies.py +++ b/pydantic_ai_slim/pydantic_ai/dependencies.py @@ -82,7 +82,7 @@ class RunContext(Generic[AgentDeps]): Usage `ToolPlainFunc[ToolParams]`. """ -ToolFuncEither = ToolFuncContext[AgentDeps, ToolParams] | ToolFuncPlain[ToolParams] +ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]] """Either kind of tool function. This is just a union of [`ToolFuncContext`][pydantic_ai.dependencies.ToolFuncContext] and From 29258f5a2189fec5d1c98011232dbd96827224fd Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 4 Dec 2024 09:03:52 +0000 Subject: [PATCH 4/8] move dependencies.py into tool.py --- pydantic_ai_slim/pydantic_ai/__init__.py | 3 +- pydantic_ai_slim/pydantic_ai/_pydantic.py | 2 +- pydantic_ai_slim/pydantic_ai/_result.py | 2 +- .../pydantic_ai/_system_prompt.py | 2 +- pydantic_ai_slim/pydantic_ai/agent.py | 3 +- pydantic_ai_slim/pydantic_ai/dependencies.py | 92 ------------------ pydantic_ai_slim/pydantic_ai/result.py | 2 +- pydantic_ai_slim/pydantic_ai/tool.py | 93 ++++++++++++++++++- 8 files changed, 95 insertions(+), 104 deletions(-) delete mode 100644 pydantic_ai_slim/pydantic_ai/dependencies.py diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 1a24297e4c..282decd18f 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -1,9 +1,8 @@ from importlib.metadata import version from .agent import Agent -from .dependencies import RunContext from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError -from .tool import Tool +from .tool import RunContext, Tool __all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__' __version__ = version('pydantic_ai_slim') diff --git a/pydantic_ai_slim/pydantic_ai/_pydantic.py b/pydantic_ai_slim/pydantic_ai/_pydantic.py index 01641d6db8..6aaf4f1372 100644 --- a/pydantic_ai_slim/pydantic_ai/_pydantic.py +++ b/pydantic_ai_slim/pydantic_ai/_pydantic.py @@ -189,7 +189,7 @@ def _build_schema( def _is_call_ctx(annotation: Any) -> bool: - from .dependencies import RunContext + from .tool import RunContext return annotation is RunContext or ( _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext diff --git a/pydantic_ai_slim/pydantic_ai/_result.py b/pydantic_ai_slim/pydantic_ai/_result.py index cbd2880810..1b7e506ec0 100644 --- a/pydantic_ai_slim/pydantic_ai/_result.py +++ b/pydantic_ai_slim/pydantic_ai/_result.py @@ -11,10 +11,10 @@ from typing_extensions import Self, TypeAliasType, TypedDict from . import _utils, messages -from .dependencies import AgentDeps, ResultValidatorFunc, RunContext from .exceptions import ModelRetry from .messages import ModelStructuredResponse, ToolCall from .result import ResultData +from .tool import AgentDeps, ResultValidatorFunc, RunContext @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/_system_prompt.py b/pydantic_ai_slim/pydantic_ai/_system_prompt.py index 04096c4e02..e7e6ec0b4b 100644 --- a/pydantic_ai_slim/pydantic_ai/_system_prompt.py +++ b/pydantic_ai_slim/pydantic_ai/_system_prompt.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Generic, cast from . import _utils -from .dependencies import AgentDeps, RunContext, SystemPromptFunc +from .tool import AgentDeps, RunContext, SystemPromptFunc @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 0b4a868e10..8555534f07 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -18,9 +18,8 @@ models, result, ) -from .dependencies import AgentDeps, RunContext, ToolFuncContext, ToolFuncEither, ToolFuncPlain, ToolParams from .result import ResultData -from .tool import Tool +from .tool import AgentDeps, RunContext, Tool, ToolFuncContext, ToolFuncEither, ToolFuncPlain, ToolParams __all__ = ('Agent',) diff --git a/pydantic_ai_slim/pydantic_ai/dependencies.py b/pydantic_ai_slim/pydantic_ai/dependencies.py deleted file mode 100644 index 807f273ed0..0000000000 --- a/pydantic_ai_slim/pydantic_ai/dependencies.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations as _annotations - -from collections.abc import Awaitable, Mapping, Sequence -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union - -from typing_extensions import Concatenate, ParamSpec, TypeAlias - -if TYPE_CHECKING: - from .result import ResultData -else: - ResultData = Any - -__all__ = ( - 'AgentDeps', - 'RunContext', - 'ResultValidatorFunc', - 'SystemPromptFunc', - 'ToolReturnValue', - 'ToolFuncContext', - 'ToolFuncPlain', - 'ToolFuncEither', - 'ToolParams', - 'JsonData', -) - -AgentDeps = TypeVar('AgentDeps') -"""Type variable for agent dependencies.""" - - -@dataclass -class RunContext(Generic[AgentDeps]): - """Information about the current call.""" - - deps: AgentDeps - """Dependencies for the agent.""" - retry: int - """Number of retries so far.""" - tool_name: str | None - """Name of the tool being called.""" - - -ToolParams = ParamSpec('ToolParams') -"""Retrieval function param spec.""" - -SystemPromptFunc = Union[ - Callable[[RunContext[AgentDeps]], str], - Callable[[RunContext[AgentDeps]], Awaitable[str]], - Callable[[], str], - Callable[[], Awaitable[str]], -] -"""A function that may or maybe not take `RunContext` as an argument, and may or may not be async. - -Usage `SystemPromptFunc[AgentDeps]`. -""" - -ResultValidatorFunc = Union[ - Callable[[RunContext[AgentDeps], ResultData], ResultData], - Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]], - Callable[[ResultData], ResultData], - Callable[[ResultData], Awaitable[ResultData]], -] -""" -A function that always takes `ResultData` and returns `ResultData`, -but may or maybe not take `CallInfo` as a first argument, and may or may not be async. - -Usage `ResultValidator[AgentDeps, ResultData]`. -""" - -JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]' -"""Type representing any JSON data.""" - -ToolReturnValue = Union[JsonData, Awaitable[JsonData]] -"""Return value of a tool function.""" -ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue] -"""A tool function that takes `RunContext` as the first argument. - -Usage `ToolContextFunc[AgentDeps, ToolParams]`. -""" -ToolFuncPlain = Callable[ToolParams, ToolReturnValue] -"""A tool function that does not take `RunContext` as the first argument. - -Usage `ToolPlainFunc[ToolParams]`. -""" -ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]] -"""Either kind of tool function. - -This is just a union of [`ToolFuncContext`][pydantic_ai.dependencies.ToolFuncContext] and -[`ToolFuncPlain`][pydantic_ai.dependencies.ToolFuncPlain]. - -Usage `ToolFuncEither[AgentDeps, ToolParams]`. -""" diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index c4165f2c85..922fc70455 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -9,7 +9,7 @@ import logfire_api from . import _result, _utils, exceptions, messages, models -from .dependencies import AgentDeps +from .tool import AgentDeps __all__ = ( 'ResultData', diff --git a/pydantic_ai_slim/pydantic_ai/tool.py b/pydantic_ai_slim/pydantic_ai/tool.py index 0d72de8619..a878a0c43b 100644 --- a/pydantic_ai_slim/pydantic_ai/tool.py +++ b/pydantic_ai_slim/pydantic_ai/tool.py @@ -1,18 +1,103 @@ from __future__ import annotations as _annotations import inspect -from collections.abc import Awaitable +from collections.abc import Awaitable, Mapping, Sequence from dataclasses import dataclass, field -from typing import Any, Callable, Generic, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast from pydantic import ValidationError from pydantic_core import SchemaValidator +from typing_extensions import Concatenate, ParamSpec, TypeAlias from . import _pydantic, _utils, messages -from .dependencies import AgentDeps, RunContext, ToolFuncEither from .exceptions import ModelRetry, UnexpectedModelBehavior -__all__ = ('Tool',) +if TYPE_CHECKING: + from .result import ResultData +else: + ResultData = Any + + +__all__ = ( + 'AgentDeps', + 'RunContext', + 'ResultValidatorFunc', + 'SystemPromptFunc', + 'ToolReturnValue', + 'ToolFuncContext', + 'ToolFuncPlain', + 'ToolFuncEither', + 'ToolParams', + 'JsonData', + 'Tool', +) + +AgentDeps = TypeVar('AgentDeps') +"""Type variable for agent dependencies.""" + + +@dataclass +class RunContext(Generic[AgentDeps]): + """Information about the current call.""" + + deps: AgentDeps + """Dependencies for the agent.""" + retry: int + """Number of retries so far.""" + tool_name: str | None + """Name of the tool being called.""" + + +ToolParams = ParamSpec('ToolParams') +"""Retrieval function param spec.""" + +SystemPromptFunc = Union[ + Callable[[RunContext[AgentDeps]], str], + Callable[[RunContext[AgentDeps]], Awaitable[str]], + Callable[[], str], + Callable[[], Awaitable[str]], +] +"""A function that may or maybe not take `RunContext` as an argument, and may or may not be async. + +Usage `SystemPromptFunc[AgentDeps]`. +""" + +ResultValidatorFunc = Union[ + Callable[[RunContext[AgentDeps], ResultData], ResultData], + Callable[[RunContext[AgentDeps], ResultData], Awaitable[ResultData]], + Callable[[ResultData], ResultData], + Callable[[ResultData], Awaitable[ResultData]], +] +""" +A function that always takes `ResultData` and returns `ResultData`, +but may or maybe not take `CallInfo` as a first argument, and may or may not be async. + +Usage `ResultValidator[AgentDeps, ResultData]`. +""" + +JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]' +"""Type representing any JSON data.""" + +ToolReturnValue = Union[JsonData, Awaitable[JsonData]] +"""Return value of a tool function.""" +ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue] +"""A tool function that takes `RunContext` as the first argument. + +Usage `ToolContextFunc[AgentDeps, ToolParams]`. +""" +ToolFuncPlain = Callable[ToolParams, ToolReturnValue] +"""A tool function that does not take `RunContext` as the first argument. + +Usage `ToolPlainFunc[ToolParams]`. +""" +ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]] +"""Either kind of tool function. + +This is just a union of [`ToolFuncContext`][pydantic_ai.dependencies.ToolFuncContext] and +[`ToolFuncPlain`][pydantic_ai.dependencies.ToolFuncPlain]. + +Usage `ToolFuncEither[AgentDeps, ToolParams]`. +""" @dataclass From 0f14c3905b5480e2dfca96db9031ae8884c31558 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 4 Dec 2024 09:04:42 +0000 Subject: [PATCH 5/8] rename tool.py -> tools.py --- pydantic_ai_slim/pydantic_ai/__init__.py | 2 +- pydantic_ai_slim/pydantic_ai/_pydantic.py | 2 +- pydantic_ai_slim/pydantic_ai/_result.py | 2 +- pydantic_ai_slim/pydantic_ai/_system_prompt.py | 2 +- pydantic_ai_slim/pydantic_ai/agent.py | 2 +- pydantic_ai_slim/pydantic_ai/result.py | 2 +- pydantic_ai_slim/pydantic_ai/{tool.py => tools.py} | 0 7 files changed, 6 insertions(+), 6 deletions(-) rename pydantic_ai_slim/pydantic_ai/{tool.py => tools.py} (100%) diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 282decd18f..119699afc6 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -2,7 +2,7 @@ from .agent import Agent from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError -from .tool import RunContext, Tool +from .tools import RunContext, Tool __all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__' __version__ = version('pydantic_ai_slim') diff --git a/pydantic_ai_slim/pydantic_ai/_pydantic.py b/pydantic_ai_slim/pydantic_ai/_pydantic.py index 6aaf4f1372..8b7486084d 100644 --- a/pydantic_ai_slim/pydantic_ai/_pydantic.py +++ b/pydantic_ai_slim/pydantic_ai/_pydantic.py @@ -189,7 +189,7 @@ def _build_schema( def _is_call_ctx(annotation: Any) -> bool: - from .tool import RunContext + from .tools import RunContext return annotation is RunContext or ( _typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext diff --git a/pydantic_ai_slim/pydantic_ai/_result.py b/pydantic_ai_slim/pydantic_ai/_result.py index 1b7e506ec0..3622f8664e 100644 --- a/pydantic_ai_slim/pydantic_ai/_result.py +++ b/pydantic_ai_slim/pydantic_ai/_result.py @@ -14,7 +14,7 @@ from .exceptions import ModelRetry from .messages import ModelStructuredResponse, ToolCall from .result import ResultData -from .tool import AgentDeps, ResultValidatorFunc, RunContext +from .tools import AgentDeps, ResultValidatorFunc, RunContext @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/_system_prompt.py b/pydantic_ai_slim/pydantic_ai/_system_prompt.py index e7e6ec0b4b..7f05079a13 100644 --- a/pydantic_ai_slim/pydantic_ai/_system_prompt.py +++ b/pydantic_ai_slim/pydantic_ai/_system_prompt.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Generic, cast from . import _utils -from .tool import AgentDeps, RunContext, SystemPromptFunc +from .tools import AgentDeps, RunContext, SystemPromptFunc @dataclass diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 8555534f07..b8da6260fe 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -19,7 +19,7 @@ result, ) from .result import ResultData -from .tool import AgentDeps, RunContext, Tool, ToolFuncContext, ToolFuncEither, ToolFuncPlain, ToolParams +from .tools import AgentDeps, RunContext, Tool, ToolFuncContext, ToolFuncEither, ToolFuncPlain, ToolParams __all__ = ('Agent',) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 922fc70455..06c5c4016f 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -9,7 +9,7 @@ import logfire_api from . import _result, _utils, exceptions, messages, models -from .tool import AgentDeps +from .tools import AgentDeps __all__ = ( 'ResultData', diff --git a/pydantic_ai_slim/pydantic_ai/tool.py b/pydantic_ai_slim/pydantic_ai/tools.py similarity index 100% rename from pydantic_ai_slim/pydantic_ai/tool.py rename to pydantic_ai_slim/pydantic_ai/tools.py From 9ca65fd0875054692637ab6b5b8447126252ea7a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 4 Dec 2024 10:05:29 +0000 Subject: [PATCH 6/8] more tests --- docs/agents.md | 12 ++-- docs/api/dependencies.md | 3 - docs/api/tool.md | 3 - docs/api/tools.md | 3 + docs/dependencies.md | 10 +-- docs/index.md | 4 +- mkdocs.yml | 3 +- pydantic_ai_slim/pydantic_ai/_pydantic.py | 18 ++++++ pydantic_ai_slim/pydantic_ai/agent.py | 26 ++++---- pydantic_ai_slim/pydantic_ai/tools.py | 79 +++++++++++++++++------ tests/test_examples.py | 3 + tests/test_tools.py | 44 +++++++++++-- tests/typed_agent.py | 28 +++++++- 13 files changed, 180 insertions(+), 56 deletions(-) delete mode 100644 docs/api/dependencies.md delete mode 100644 docs/api/tool.md create mode 100644 docs/api/tools.md diff --git a/docs/agents.md b/docs/agents.md index 5564cbc2c0..111be57a26 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -49,7 +49,7 @@ print(result.data) ``` 1. Create an agent, which expects an integer dependency and returns a boolean result. This agent will have type `#!python Agent[int, bool]`. -2. Define a tool that checks if the square is a winner. Here [`RunContext`][pydantic_ai.dependencies.RunContext] is parameterized with the dependency type `int`; if you got the dependency type wrong you'd get a typing error. +2. Define a tool that checks if the square is a winner. Here [`RunContext`][pydantic_ai.tools.RunContext] is parameterized with the dependency type `int`; if you got the dependency type wrong you'd get a typing error. 3. In reality, you might want to use a random number here e.g. `random.randint(0, 36)`. 4. `result.data` will be a boolean indicating if the square is a winner. Pydantic performs the result validation, it'll be typed as a `bool` since its type is derived from the `result_type` generic parameter of the agent. @@ -222,7 +222,7 @@ print(result.data) 1. The agent expects a string dependency. 2. Static system prompt defined at agent creation time. -3. Dynamic system prompt defined via a decorator with [`RunContext`][pydantic_ai.dependencies.RunContext], this is called just after `run_sync`, not when the agent is created, so can benefit from runtime information like the dependencies used on that run. +3. Dynamic system prompt defined via a decorator with [`RunContext`][pydantic_ai.tools.RunContext], this is called just after `run_sync`, not when the agent is created, so can benefit from runtime information like the dependencies used on that run. 4. Another dynamic system prompt, system prompts don't have to have the `RunContext` parameter. _(This example is complete, it can be run "as is")_ @@ -240,8 +240,8 @@ They're useful when it is impractical or impossible to put all the context an ag There are two different decorator functions to register tools: -1. [`@agent.tool`][pydantic_ai.Agent.tool] — for tools that need access to the agent [context][pydantic_ai.dependencies.RunContext] -2. [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] — for tools that do not need access to the agent [context][pydantic_ai.dependencies.RunContext] +1. [`@agent.tool`][pydantic_ai.Agent.tool] — for tools that need access to the agent [context][pydantic_ai.tools.RunContext] +2. [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] — for tools that do not need access to the agent [context][pydantic_ai.tools.RunContext] `@agent.tool` is the default since in the majority of cases tools will need access to the agent context. @@ -445,7 +445,7 @@ agent.run_sync('hello', model=FunctionModel(print_schema)) _(This example is complete, it can be run "as is")_ -The return type of tool can be any valid JSON object ([`JsonData`][pydantic_ai.dependencies.JsonData]) as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON. +The return type of tool can be any valid JSON object ([`JsonData`][pydantic_ai.tools.JsonData]) as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON. If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object. (TODO example) @@ -456,7 +456,7 @@ Validation errors from both function tool parameter validation and [structured r You can also raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] from within a [tool](#function-tools) or [result validator function](results.md#result-validators-functions) to tell the model it should retry generating a response. - The default retry count is **1** but can be altered for the [entire agent][pydantic_ai.Agent.__init__], a [specific tool][pydantic_ai.Agent.tool], or a [result validator][pydantic_ai.Agent.__init__]. -- You can access the current retry count from within a tool or result validator via [`ctx.retry`][pydantic_ai.dependencies.RunContext]. +- You can access the current retry count from within a tool or result validator via [`ctx.retry`][pydantic_ai.tools.RunContext]. Here's an example: diff --git a/docs/api/dependencies.md b/docs/api/dependencies.md deleted file mode 100644 index 9f49436a0a..0000000000 --- a/docs/api/dependencies.md +++ /dev/null @@ -1,3 +0,0 @@ -# `pydantic_ai.dependencies` - -::: pydantic_ai.dependencies diff --git a/docs/api/tool.md b/docs/api/tool.md deleted file mode 100644 index 9e3becc1bb..0000000000 --- a/docs/api/tool.md +++ /dev/null @@ -1,3 +0,0 @@ -# `pydantic_ai.tool` - -::: pydantic_ai.tool diff --git a/docs/api/tools.md b/docs/api/tools.md new file mode 100644 index 0000000000..3dcf88d566 --- /dev/null +++ b/docs/api/tools.md @@ -0,0 +1,3 @@ +# `pydantic_ai.tools` + +::: pydantic_ai.tools diff --git a/docs/dependencies.md b/docs/dependencies.md index 129b58228a..b088ba94bc 100644 --- a/docs/dependencies.md +++ b/docs/dependencies.md @@ -51,7 +51,7 @@ _(This example is complete, it can be run "as is")_ ## Accessing Dependencies -Dependencies are accessed through the [`RunContext`][pydantic_ai.dependencies.RunContext] type, this should be the first parameter of system prompt functions etc. +Dependencies are accessed through the [`RunContext`][pydantic_ai.tools.RunContext] type, this should be the first parameter of system prompt functions etc. ```py title="system_prompt_dependencies.py" hl_lines="20-27" @@ -92,10 +92,10 @@ async def main(): #> Did you hear about the toothpaste scandal? They called it Colgate. ``` -1. [`RunContext`][pydantic_ai.dependencies.RunContext] may optionally be passed to a [`system_prompt`][pydantic_ai.Agent.system_prompt] function as the only argument. -2. [`RunContext`][pydantic_ai.dependencies.RunContext] is parameterized with the type of the dependencies, if this type is incorrect, static type checkers will raise an error. -3. Access dependencies through the [`.deps`][pydantic_ai.dependencies.RunContext.deps] attribute. -4. Access dependencies through the [`.deps`][pydantic_ai.dependencies.RunContext.deps] attribute. +1. [`RunContext`][pydantic_ai.tools.RunContext] may optionally be passed to a [`system_prompt`][pydantic_ai.Agent.system_prompt] function as the only argument. +2. [`RunContext`][pydantic_ai.tools.RunContext] is parameterized with the type of the dependencies, if this type is incorrect, static type checkers will raise an error. +3. Access dependencies through the [`.deps`][pydantic_ai.tools.RunContext.deps] attribute. +4. Access dependencies through the [`.deps`][pydantic_ai.tools.RunContext.deps] attribute. _(This example is complete, it can be run "as is")_ diff --git a/docs/index.md b/docs/index.md index f6b233b71a..69d7e14f52 100644 --- a/docs/index.md +++ b/docs/index.md @@ -125,8 +125,8 @@ async def main(): 2. Here we configure the agent to use [OpenAI's GPT-4o model](api/models/openai.md), you can also set the model when running the agent. 3. The `SupportDependencies` dataclass is used to pass data, connections, and logic into the model that will be needed when running [system prompt](agents.md#system-prompts) and [tool](agents.md#function-tools) functions. PydanticAI's system of dependency injection provides a [type-safe](agents.md#static-type-checking) way to customise the behavior of your agents, and can be especially useful when running [unit tests](testing-evals.md) and evals. 4. Static [system prompts](agents.md#system-prompts) can be registered with the [`system_prompt` keyword argument][pydantic_ai.Agent.__init__] to the agent. -5. Dynamic [system prompts](agents.md#system-prompts) can be registered with the [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt] decorator, and can make use of dependency injection. Dependencies are carried via the [`RunContext`][pydantic_ai.dependencies.RunContext] argument, which is parameterized with the `deps_type` from above. If the type annotation here is wrong, static type checkers will catch it. -6. [`tool`](agents.md#function-tools) let you register functions which the LLM may call while responding to a user. Again, dependencies are carried via [`RunContext`][pydantic_ai.dependencies.RunContext], any other arguments become the tool schema passed to the LLM. Pydantic is used to validate these arguments, and errors are passed back to the LLM so it can retry. +5. Dynamic [system prompts](agents.md#system-prompts) can be registered with the [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt] decorator, and can make use of dependency injection. Dependencies are carried via the [`RunContext`][pydantic_ai.tools.RunContext] argument, which is parameterized with the `deps_type` from above. If the type annotation here is wrong, static type checkers will catch it. +6. [`tool`](agents.md#function-tools) let you register functions which the LLM may call while responding to a user. Again, dependencies are carried via [`RunContext`][pydantic_ai.tools.RunContext], any other arguments become the tool schema passed to the LLM. Pydantic is used to validate these arguments, and errors are passed back to the LLM so it can retry. 7. The docstring of a tool is also passed to the LLM as the description of the tool. Parameter descriptions are [extracted](agents.md#function-tools-and-schema) from the docstring and added to the parameter schema sent to the LLM. 8. [Run the agent](agents.md#running-agents) asynchronously, conducting a conversation with the LLM until a final response is reached. Even in this fairly simple case, the agent will exchange multiple messages with the LLM as tools are called to retrieve a result. 9. The response from the agent will, be guaranteed to be a `SupportResult`, if validation fails [reflection](agents.md#reflection-and-self-correction) will mean the agent is prompted to try again. diff --git a/mkdocs.yml b/mkdocs.yml index ea91b6e0ab..d46b774a93 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -31,10 +31,9 @@ nav: - examples/chat-app.md - API Reference: - api/agent.md - - api/tool.md + - api/tools.md - api/result.md - api/messages.md - - api/dependencies.md - api/exceptions.md - api/models/base.md - api/models/openai.md diff --git a/pydantic_ai_slim/pydantic_ai/_pydantic.py b/pydantic_ai_slim/pydantic_ai/_pydantic.py index 8b7486084d..b4cf35dfba 100644 --- a/pydantic_ai_slim/pydantic_ai/_pydantic.py +++ b/pydantic_ai_slim/pydantic_ai/_pydantic.py @@ -157,6 +157,24 @@ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSc ) +def takes_ctx(function: Callable[..., Any]) -> bool: + """Check if a function takes a `RunContext` first argument. + + Args: + function: The function to check. + + Returns: + `True` if the function takes a `RunContext` as first argument, `False` otherwise. + """ + sig = signature(function) + try: + _, first_param = next(iter(sig.parameters.items())) + except StopIteration: + return False + else: + return first_param.annotation is not sig.empty and _is_call_ctx(first_param.annotation) + + def _build_schema( fields: dict[str, core_schema.TypedDictField], var_kwargs_schema: core_schema.CoreSchema | None, diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index b8da6260fe..5f8e2a8d78 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1,6 +1,7 @@ from __future__ import annotations as _annotations import asyncio +import dataclasses from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass, field @@ -33,7 +34,7 @@ class Agent(Generic[AgentDeps, ResultData]): """Class for defining "agents" - a way to have a specific type of "conversation" with an LLM. - Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.dependencies.AgentDeps] + Agents are generic in the dependency type they take [`AgentDeps`][pydantic_ai.tools.AgentDeps] and the result data type they return, [`ResultData`][pydantic_ai.result.ResultData]. By default, if neither generic parameter is customised, agents have type `Agent[None, str]`. @@ -82,7 +83,7 @@ def __init__( result_tool_name: str = 'final_result', result_tool_description: str | None = None, result_retries: int | None = None, - tools: Sequence[Tool[AgentDeps]] = (), + tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (), defer_model_check: bool = False, ): """Create an agent. @@ -122,10 +123,10 @@ def __init__( self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) self._function_tools = {} + self._default_retries = retries for tool in tools: - self._register_tool(tool) + self._register_tool(Tool.infer(tool)) self._deps_type = deps_type - self._default_retries = retries self._system_prompt_functions = [] self._max_result_retries = result_retries if result_retries is not None else retries self._current_result_retry = 0 @@ -359,7 +360,7 @@ def system_prompt( ) -> _system_prompt.SystemPromptFunc[AgentDeps]: """Decorator to register a system prompt function. - Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's only argument. + Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as it's only argument. Can decorate a sync or async functions. Overloads for every possible signature of `system_prompt` are included so the decorator doesn't obscure @@ -410,7 +411,7 @@ def result_validator( ) -> _result.ResultValidatorFunc[AgentDeps, ResultData]: """Decorator to register a result validator function. - Optionally takes [`RunContext`][pydantic_ai.dependencies.RunContext] as it's first argument. + Optionally takes [`RunContext`][pydantic_ai.tools.RunContext] as it's first argument. Can decorate a sync or async functions. Overloads for every possible signature of `result_validator` are included so the decorator doesn't obscure @@ -458,7 +459,7 @@ def tool( retries: int | None = None, ) -> Any: """Decorator to register a tool function which takes - [`RunContext`][pydantic_ai.dependencies.RunContext] as its first argument. + [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. Can decorate a sync or async functions. @@ -567,16 +568,19 @@ def _register_function( ) -> None: """Private utility to register a function as a tool.""" retries_ = retries if retries is not None else self._default_retries - tool = Tool(func, takes_ctx, retries_) + tool = Tool(func, takes_ctx, max_retries=retries_) self._register_tool(tool) def _register_tool(self, tool: Tool[AgentDeps]) -> None: """Private utility to register a tool instance.""" - if self._result_schema and tool.name in self._result_schema.tools: - raise ValueError(f'Tool name conflicts with result schema name: {tool.name!r}') + if tool.max_retries is None: + tool = dataclasses.replace(tool, max_retries=self._default_retries) if tool.name in self._function_tools: - raise ValueError(f'Tool name conflicts with existing tool: {tool.name!r}') + raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}') + + if self._result_schema and tool.name in self._result_schema.tools: + raise exceptions.UserError(f'Tool name conflicts with result schema name: {tool.name!r}') self._function_tools[tool.name] = tool diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index a878a0c43b..61995c3ad6 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -7,7 +7,7 @@ from pydantic import ValidationError from pydantic_core import SchemaValidator -from typing_extensions import Concatenate, ParamSpec, TypeAlias +from typing_extensions import Concatenate, ParamSpec, TypeAlias, final from . import _pydantic, _utils, messages from .exceptions import ModelRetry, UnexpectedModelBehavior @@ -93,27 +93,25 @@ class RunContext(Generic[AgentDeps]): ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]] """Either kind of tool function. -This is just a union of [`ToolFuncContext`][pydantic_ai.dependencies.ToolFuncContext] and -[`ToolFuncPlain`][pydantic_ai.dependencies.ToolFuncPlain]. +This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and +[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain]. Usage `ToolFuncEither[AgentDeps, ToolParams]`. """ +A = TypeVar('A') -@dataclass + +@final +@dataclass(init=False) class Tool(Generic[AgentDeps]): """A tool function for an agent.""" function: ToolFuncEither[AgentDeps, ...] - """The Python function to call as the tool.""" takes_ctx: bool - """Whether the function takes a [`RunContext`][pydantic_ai.dependencies.RunContext] first argument.""" - max_retries: int = 1 - """Maximum number of retries allowed for this tool.""" - name: str = '' - """Name of the tool, inferred from the function if left blank.""" - description: str = '' - """Description of the tool, inferred from the function if left blank.""" + max_retries: int | None + name: str + description: str _is_async: bool = field(init=False) _single_arg_name: str | None = field(init=False) _positional_fields: list[str] = field(init=False) @@ -122,10 +120,41 @@ class Tool(Generic[AgentDeps]): _json_schema: _utils.ObjectJsonSchema = field(init=False) _current_retry: int = field(default=0, init=False) - def __post_init__(self): - f = _pydantic.function_schema(self.function, self.takes_ctx) - self.name = self.name or self.function.__name__ - self.description = self.description or f['description'] + def __init__( + self, + function: ToolFuncEither[AgentDeps, ...], + takes_ctx: bool, + *, + max_retries: int | None = None, + name: str | None = None, + description: str | None = None, + ): + """Create a new tool instance. + + Example usage: + + ```py + from pydantic_ai import Agent, RunContext, Tool + + async def my_tool(ctx: RunContext[int], x: int, y: int) -> str: + return f'{ctx.deps} {x} {y}' + + agent = Agent('test', tools=[Tool(my_tool, True)]) + ``` + + Args: + function: The Python function to call as the tool. + takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument. + max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`. + name: Name of the tool, inferred from the function if left blank. + description: Description of the tool, inferred from the function if left blank. + """ + f = _pydantic.function_schema(function, takes_ctx) + self.function = function + self.takes_ctx = takes_ctx + self.max_retries = max_retries + self.name = name or function.__name__ + self.description = description or f['description'] self._is_async = inspect.iscoroutinefunction(self.function) self._single_arg_name = f['single_arg_name'] self._positional_fields = f['positional_fields'] @@ -133,6 +162,21 @@ def __post_init__(self): self._validator = f['validator'] self._json_schema = f['json_schema'] + @staticmethod + def infer(function: ToolFuncEither[A, ...] | Tool[A]) -> Tool[A]: + """Create a tool from a pure function, inferring whether it takes `RunContext` as its first argument. + + Args: + function: The tool function to wrap; or for convenience, a `Tool` instance. + + Returns: + A new `Tool` instance. + """ + if isinstance(function, Tool): + return function + else: + return Tool(function, takes_ctx=_pydantic.takes_ctx(function)) + def reset(self) -> None: """Reset the current retry count.""" self._current_retry = 0 @@ -189,8 +233,7 @@ def _call_args( def _on_error(self, exc: ValidationError | ModelRetry, call_message: messages.ToolCall) -> messages.RetryPrompt: self._current_retry += 1 - if self._current_retry > self.max_retries: - # TODO custom error with details of the tool + if self.max_retries is None or self._current_retry > self.max_retries: raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc else: if isinstance(exc, ValidationError): diff --git a/tests/test_examples.py b/tests/test_examples.py index b7856f1c63..527e4b03cd 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -98,6 +98,9 @@ def test_docs_examples( # waiting for https://github.com/pydantic/pytest-examples/issues/43 if 'import DatabaseConn' in example.source: ruff_ignore.append('I001') + elif 'async def my_tool(' in example.source: + # not sure what's going on here, just ignore it for now + ruff_ignore.append('I001') line_length = 88 if opt_title in ('streamed_hello_world.py', 'streamed_user_profile.py'): diff --git a/tests/test_tools.py b/tests/test_tools.py index 761294dbac..04edef3833 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -275,6 +275,7 @@ def takes_just_model(model: Foo, z: int) -> str: assert result.data == snapshot('{"takes_just_model":"0 a 0"}') +# pyright: reportPrivateUsage=false def test_init_tool_plain(set_event_loop: None): call_args: list[int] = [] @@ -282,16 +283,49 @@ def plain_tool(x: int) -> int: call_args.append(x) return x + 1 - agent = Agent('test', tools=[Tool(plain_tool, False)]) + agent = Agent('test', tools=[Tool(plain_tool, False)], retries=7) result = agent.run_sync('foobar') assert result.data == snapshot('{"plain_tool":1}') assert call_args == snapshot([0]) + assert agent._function_tools['plain_tool'].takes_ctx is False + assert agent._function_tools['plain_tool'].max_retries == 7 + agent_infer = Agent('test', tools=[plain_tool], retries=7) + result = agent_infer.run_sync('foobar') + assert result.data == snapshot('{"plain_tool":1}') + assert call_args == snapshot([0, 0]) + assert agent_infer._function_tools['plain_tool'].takes_ctx is False + assert agent_infer._function_tools['plain_tool'].max_retries == 7 -def test_init_tool_ctx(set_event_loop: None): - def ctx_tool(ctx: RunContext[int], x: int) -> int: - return x + ctx.deps - agent = Agent('test', tools=[Tool(ctx_tool, True)], deps_type=int) +def ctx_tool(ctx: RunContext[int], x: int) -> int: + return x + ctx.deps + + +# pyright: reportPrivateUsage=false +def test_init_tool_ctx(set_event_loop: None): + agent = Agent('test', tools=[Tool(ctx_tool, True, max_retries=3)], deps_type=int, retries=7) result = agent.run_sync('foobar', deps=5) assert result.data == snapshot('{"ctx_tool":5}') + assert agent._function_tools['ctx_tool'].takes_ctx is True + assert agent._function_tools['ctx_tool'].max_retries == 3 + + agent_infer = Agent('test', tools=[ctx_tool], deps_type=int) + result = agent_infer.run_sync('foobar', deps=6) + assert result.data == snapshot('{"ctx_tool":6}') + assert agent_infer._function_tools['ctx_tool'].takes_ctx is True + + +def test_repeat_tool(): + with pytest.raises(UserError, match="Tool name conflicts with existing tool: 'ctx_tool'"): + Agent('test', tools=[Tool(ctx_tool, True), ctx_tool], deps_type=int) + + +def test_tool_return_conflict(): + # this is okay + Agent('test', tools=[ctx_tool], deps_type=int) + # this is also okay + Agent('test', tools=[ctx_tool], deps_type=int, result_type=int) + # this raises an error + with pytest.raises(UserError, match="Tool name conflicts with result schema name: 'ctx_tool'"): + Agent('test', tools=[ctx_tool], deps_type=int, result_type=int, result_tool_name='ctx_tool') diff --git a/tests/typed_agent.py b/tests/typed_agent.py index a73f389b82..9546adc46a 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Callable, Union, assert_type -from pydantic_ai import Agent, ModelRetry, RunContext +from pydantic_ai import Agent, ModelRetry, RunContext, Tool from pydantic_ai.result import RunResult @@ -151,3 +151,29 @@ def run_sync3() -> None: result = union_agent.run_sync('testing') assert_type(result, RunResult[Union[Foo, Bar]]) assert_type(result.data, Union[Foo, Bar]) + + +def foobar_ctx(ctx: RunContext[int], x: str, y: int) -> str: + return f'{x} {y}' + + +def foobar_plain(x: str, y: int) -> str: + return f'{x} {y}' + + +Tool(foobar_ctx, True) +Tool(foobar_plain, False) + +# unfortunately we can't type check these cases, since from a typing perspect `foobar_ctx` is valid as a plain tool +Tool(foobar_ctx, False) +Tool(foobar_plain, True) + +Agent('test', tools=[foobar_ctx], deps_type=int) +Agent('test', tools=[foobar_plain], deps_type=int) +Agent('test', tools=[foobar_plain]) +Agent('test', tools=[Tool(foobar_ctx, True)], deps_type=int) +Agent('test', tools=[Tool(foobar_plain, False)], deps_type=int) +Agent('test', tools=[Tool(foobar_ctx, True), foobar_ctx, foobar_plain], deps_type=int) + +Agent('test', tools=[foobar_ctx]) # pyright: ignore[reportArgumentType] +Agent('test', tools=[Tool(foobar_ctx, True)]) # pyright: ignore[reportArgumentType] From eabe6f186eaefc14d8fa0599093614fc8106348a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 4 Dec 2024 10:25:41 +0000 Subject: [PATCH 7/8] add docs for tool kwarg --- docs/agents.md | 66 +++++++++++++++++++---- pydantic_ai_slim/pydantic_ai/_pydantic.py | 6 +-- tests/test_examples.py | 2 +- tests/test_tools.py | 22 ++++++-- tests/typed_agent.py | 1 + 5 files changed, 78 insertions(+), 19 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 111be57a26..00280baed4 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -161,7 +161,7 @@ def foobar(x: bytes) -> None: pass -result = agent.run_sync('Does their name start with "A"?', deps=User('Adam')) +result = agent.run_sync('Does their name start with "A"?', deps=User('Anne')) foobar(result.data) # (3)! ``` @@ -238,12 +238,13 @@ They're useful when it is impractical or impossible to put all the context an ag The main semantic difference between PydanticAI Tools and RAG is RAG is synonymous with vector search, while PydanticAI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. See [#58](https://github.com/pydantic/pydantic-ai/issues/58)) -There are two different decorator functions to register tools: +There are a number of ways to register tools with an agent: -1. [`@agent.tool`][pydantic_ai.Agent.tool] — for tools that need access to the agent [context][pydantic_ai.tools.RunContext] -2. [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] — for tools that do not need access to the agent [context][pydantic_ai.tools.RunContext] +* via the [`@agent.tool`][pydantic_ai.Agent.tool] decorator — for tools that need access to the agent [context][pydantic_ai.tools.RunContext] +* via the [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator — for tools that do not need access to the agent [context][pydantic_ai.tools.RunContext] +* via the [`tools`][pydantic_ai.Agent.__init__] keyword argument to `Agent` which can take either plain functions, or instances of [`Tool`][pydantic_ai.tools.Tool] -`@agent.tool` is the default since in the majority of cases tools will need access to the agent context. +`@agent.tool` is considered the default decorator since in the majority of cases tools will need access to the agent context. Here's an example using both: @@ -275,9 +276,9 @@ def get_player_name(ctx: RunContext[str]) -> str: return ctx.deps -dice_result = agent.run_sync('My guess is 4', deps='Adam') # (5)! +dice_result = agent.run_sync('My guess is 4', deps='Anne') # (5)! print(dice_result.data) -#> Congratulations Adam, you guessed correctly! You're a winner! +#> Congratulations Anne, you guessed correctly! You're a winner! ``` 1. This is a pretty simple task, so we can use the fast and cheap Gemini flash model. @@ -330,13 +331,13 @@ print(dice_result.all_messages()) ), ToolReturn( tool_name='get_player_name', - content='Adam', + content='Anne', tool_id=None, timestamp=datetime.datetime(...), role='tool-return', ), ModelTextResponse( - content="Congratulations Adam, you guessed correctly! You're a winner!", + content="Congratulations Anne, you guessed correctly! You're a winner!", timestamp=datetime.datetime(...), role='model-text-response', ), @@ -370,16 +371,59 @@ sequenceDiagram deactivate LLM activate Agent Note over Agent: Retrieves player name - Agent -->> LLM: ToolReturn
"Adam" + Agent -->> LLM: ToolReturn
"Anne" deactivate Agent activate LLM Note over LLM: LLM constructs final response - LLM ->> Agent: ModelTextResponse
"Congratulations Adam, ..." + LLM ->> Agent: ModelTextResponse
"Congratulations Anne, ..." deactivate LLM Note over Agent: Game session complete ``` +### Registering Function tools via the argument + +As well as using the decorators, we can also register tools via the `tools` argument to the [`Agent` constructor][pydantic_ai.Agent.__init__]. This is useful when you want to re-use tools, and can also give more fine-grained control over the tools. + +```py title="dice_game_tool_kwarg.py" +import random + +from pydantic_ai import Agent, RunContext, Tool + + +def roll_die() -> str: + """Roll a six-sided die and return the result.""" + return str(random.randint(1, 6)) + + +def get_player_name(ctx: RunContext[str]) -> str: + """Get the player's name.""" + return ctx.deps + + +agent_a = Agent( + 'gemini-1.5-flash', + deps_type=str, + tools=[roll_die, get_player_name], # (2)! +) +agent_b = Agent( + 'gemini-1.5-flash', + deps_type=str, + tools=[ # (2)! + Tool(roll_die, takes_ctx=False), + Tool(get_player_name, takes_ctx=True), + ], +) +dice_result = agent_b.run_sync('My guess is 4', deps='Anne') # (5)! +print(dice_result.data) +#> Congratulations Anne, you guessed correctly! You're a winner! +``` + +1. The simplest way to register tools via the `Agent` constructor is to pass a list of functions, the function signature is inspected to determine if the tool needs the agent context. +2. `agent_a` and `agent_b` are identical — but we can use [`Tool`][pydantic_ai.tools.Tool] to give more fine-grained control over how tools are defined, e.g. setting their name or description. + +_(This example is complete, it can be run "as is")_ + ### Function Tools vs. Structured Results As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call function tools while others end the run and return a result. diff --git a/pydantic_ai_slim/pydantic_ai/_pydantic.py b/pydantic_ai_slim/pydantic_ai/_pydantic.py index b4cf35dfba..4aecd5fcf1 100644 --- a/pydantic_ai_slim/pydantic_ai/_pydantic.py +++ b/pydantic_ai_slim/pydantic_ai/_pydantic.py @@ -76,13 +76,13 @@ def function_schema(function: Callable[..., Any], takes_ctx: bool) -> FunctionSc if index == 0 and takes_ctx: if not _is_call_ctx(annotation): - errors.append('First argument must be a RunContext instance when using `.tool`') + errors.append('First parameter of tools that take context must be annotated with RunContext[...]') continue elif not takes_ctx and _is_call_ctx(annotation): - errors.append('RunContext instance can only be used with `.tool`') + errors.append('RunContext annotations can only be used with tools that take context') continue elif index != 0 and _is_call_ctx(annotation): - errors.append('RunContext instance can only be used as the first argument') + errors.append('RunContext annotations can only be used as the first argument') continue field_name = p.name diff --git a/tests/test_examples.py b/tests/test_examples.py index 527e4b03cd..3f2c880dc5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -234,7 +234,7 @@ async def model_logic(messages: list[Message], info: AgentInfo) -> ModelAnyRespo elif m.role == 'tool-return' and m.tool_name == 'roll_die': return ModelStructuredResponse(calls=[ToolCall(tool_name='get_player_name', args=ArgsDict({}))]) elif m.role == 'tool-return' and m.tool_name == 'get_player_name': - return ModelTextResponse(content="Congratulations Adam, you guessed correctly! You're a winner!") + return ModelTextResponse(content="Congratulations Anne, you guessed correctly! You're a winner!") if m.role == 'retry-prompt' and isinstance(m.content, str) and m.content.startswith("No user found with name 'Joh"): return ModelStructuredResponse( calls=[ToolCall(tool_name='get_user_by_name', args=ArgsDict({'name': 'John Doe'}))] diff --git a/tests/test_tools.py b/tests/test_tools.py index 04edef3833..640b8b1ef9 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -22,7 +22,7 @@ def invalid_tool(x: int) -> str: # pragma: no cover assert str(exc_info.value) == snapshot( 'Error generating schema for test_tool_no_ctx..invalid_tool:\n' - ' First argument must be a RunContext instance when using `.tool`' + ' First parameter of tools that take context must be annotated with RunContext[...]' ) @@ -37,7 +37,7 @@ async def invalid_tool(ctx: RunContext[None]) -> str: # pragma: no cover assert str(exc_info.value) == snapshot( 'Error generating schema for test_tool_plain_with_ctx..invalid_tool:\n' - ' RunContext instance can only be used with `.tool`' + ' RunContext annotations can only be used with tools that take context' ) @@ -52,8 +52,8 @@ def invalid_tool(x: int, ctx: RunContext[None]) -> str: # pragma: no cover assert str(exc_info.value) == snapshot( 'Error generating schema for test_tool_ctx_second..invalid_tool:\n' - ' First argument must be a RunContext instance when using `.tool`\n' - ' RunContext instance can only be used as the first argument' + ' First parameter of tools that take context must be annotated with RunContext[...]\n' + ' RunContext annotations can only be used as the first argument' ) @@ -329,3 +329,17 @@ def test_tool_return_conflict(): # this raises an error with pytest.raises(UserError, match="Tool name conflicts with result schema name: 'ctx_tool'"): Agent('test', tools=[ctx_tool], deps_type=int, result_type=int, result_tool_name='ctx_tool') + + +def test_init_ctx_tool_invalid(): + def plain_tool(x: int) -> int: + return x + 1 + + m = r'First parameter of tools that take context must be annotated with RunContext\[\.\.\.\]' + with pytest.raises(UserError, match=m): + Tool(plain_tool, True) + + +def test_init_plain_tool_invalid(): + with pytest.raises(UserError, match='RunContext annotations can only be used with tools that take context'): + Tool(ctx_tool, False) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 9546adc46a..19bed49b16 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -175,5 +175,6 @@ def foobar_plain(x: str, y: int) -> str: Agent('test', tools=[Tool(foobar_plain, False)], deps_type=int) Agent('test', tools=[Tool(foobar_ctx, True), foobar_ctx, foobar_plain], deps_type=int) +Agent('test', tools=[foobar_ctx], deps_type=str) # pyright: ignore[reportArgumentType] Agent('test', tools=[foobar_ctx]) # pyright: ignore[reportArgumentType] Agent('test', tools=[Tool(foobar_ctx, True)]) # pyright: ignore[reportArgumentType] From 951dea21f76a2704eb5100ff5b08f69f47c9729d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 4 Dec 2024 10:29:04 +0000 Subject: [PATCH 8/8] docs cleanup --- docs/agents.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 00280baed4..c75869a547 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -381,9 +381,9 @@ sequenceDiagram Note over Agent: Game session complete ``` -### Registering Function tools via the argument +### Registering Function Tools via kwarg -As well as using the decorators, we can also register tools via the `tools` argument to the [`Agent` constructor][pydantic_ai.Agent.__init__]. This is useful when you want to re-use tools, and can also give more fine-grained control over the tools. +As well as using the decorators, we can register tools via the `tools` argument to the [`Agent` constructor][pydantic_ai.Agent.__init__]. This is useful when you want to re-use tools, and can also give more fine-grained control over the tools. ```py title="dice_game_tool_kwarg.py" import random @@ -404,7 +404,7 @@ def get_player_name(ctx: RunContext[str]) -> str: agent_a = Agent( 'gemini-1.5-flash', deps_type=str, - tools=[roll_die, get_player_name], # (2)! + tools=[roll_die, get_player_name], # (1)! ) agent_b = Agent( 'gemini-1.5-flash', @@ -414,12 +414,12 @@ agent_b = Agent( Tool(get_player_name, takes_ctx=True), ], ) -dice_result = agent_b.run_sync('My guess is 4', deps='Anne') # (5)! +dice_result = agent_b.run_sync('My guess is 4', deps='Anne') print(dice_result.data) #> Congratulations Anne, you guessed correctly! You're a winner! ``` -1. The simplest way to register tools via the `Agent` constructor is to pass a list of functions, the function signature is inspected to determine if the tool needs the agent context. +1. The simplest way to register tools via the `Agent` constructor is to pass a list of functions, the function signature is inspected to determine if the tool takes [`RunContext`][pydantic_ai.tools.RunContext]. 2. `agent_a` and `agent_b` are identical — but we can use [`Tool`][pydantic_ai.tools.Tool] to give more fine-grained control over how tools are defined, e.g. setting their name or description. _(This example is complete, it can be run "as is")_