Skip to content

Commit 8d4889c

Browse files
authored
Let FunctionToolset take default values for strict, sequential, requires_approval, metadata (#2909)
1 parent ab0b6f4 commit 8d4889c

File tree

3 files changed

+79
-34
lines changed

3 files changed

+79
-34
lines changed

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,7 @@ def tool(
10051005
require_parameter_descriptions: bool = False,
10061006
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
10071007
strict: bool | None = None,
1008+
sequential: bool = False,
10081009
requires_approval: bool = False,
10091010
metadata: dict[str, Any] | None = None,
10101011
) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
@@ -1021,6 +1022,7 @@ def tool(
10211022
require_parameter_descriptions: bool = False,
10221023
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
10231024
strict: bool | None = None,
1025+
sequential: bool = False,
10241026
requires_approval: bool = False,
10251027
metadata: dict[str, Any] | None = None,
10261028
) -> Any:
@@ -1067,6 +1069,7 @@ async def spam(ctx: RunContext[str], y: float) -> float:
10671069
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
10681070
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
10691071
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
1072+
sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
10701073
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
10711074
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
10721075
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
@@ -1078,17 +1081,17 @@ def tool_decorator(
10781081
# noinspection PyTypeChecker
10791082
self._function_toolset.add_function(
10801083
func_,
1081-
True,
1082-
name,
1083-
retries,
1084-
prepare,
1085-
docstring_format,
1086-
require_parameter_descriptions,
1087-
schema_generator,
1088-
strict,
1089-
False,
1090-
requires_approval,
1091-
metadata,
1084+
takes_ctx=True,
1085+
name=name,
1086+
retries=retries,
1087+
prepare=prepare,
1088+
docstring_format=docstring_format,
1089+
require_parameter_descriptions=require_parameter_descriptions,
1090+
schema_generator=schema_generator,
1091+
strict=strict,
1092+
sequential=sequential,
1093+
requires_approval=requires_approval,
1094+
metadata=metadata,
10921095
)
10931096
return func_
10941097

@@ -1109,6 +1112,7 @@ def tool_plain(
11091112
require_parameter_descriptions: bool = False,
11101113
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
11111114
strict: bool | None = None,
1115+
sequential: bool = False,
11121116
requires_approval: bool = False,
11131117
metadata: dict[str, Any] | None = None,
11141118
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
@@ -1182,17 +1186,17 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams
11821186
# noinspection PyTypeChecker
11831187
self._function_toolset.add_function(
11841188
func_,
1185-
False,
1186-
name,
1187-
retries,
1188-
prepare,
1189-
docstring_format,
1190-
require_parameter_descriptions,
1191-
schema_generator,
1192-
strict,
1193-
sequential,
1194-
requires_approval,
1195-
metadata,
1189+
takes_ctx=False,
1190+
name=name,
1191+
retries=retries,
1192+
prepare=prepare,
1193+
docstring_format=docstring_format,
1194+
require_parameter_descriptions=require_parameter_descriptions,
1195+
schema_generator=schema_generator,
1196+
strict=strict,
1197+
sequential=sequential,
1198+
requires_approval=requires_approval,
1199+
metadata=metadata,
11961200
)
11971201
return func_
11981202

pydantic_ai_slim/pydantic_ai/toolsets/function.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,31 +45,49 @@ def __init__(
4545
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
4646
*,
4747
max_retries: int = 1,
48-
id: str | None = None,
4948
docstring_format: DocstringFormat = 'auto',
5049
require_parameter_descriptions: bool = False,
5150
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
51+
strict: bool | None = None,
52+
sequential: bool = False,
53+
requires_approval: bool = False,
54+
metadata: dict[str, Any] | None = None,
55+
id: str | None = None,
5256
):
5357
"""Build a new function toolset.
5458
5559
Args:
5660
tools: The tools to add to the toolset.
5761
max_retries: The maximum number of retries for each tool during a run.
58-
id: An optional unique ID for the toolset. A toolset needs to have an ID in order to be used in a durable execution environment like Temporal,
59-
in which case the ID will be used to identify the toolset's activities within the workflow.
62+
Applies to all tools, unless overridden when adding a tool.
6063
docstring_format: Format of tool docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat].
6164
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
6265
Applies to all tools, unless overridden when adding a tool.
6366
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
6467
Applies to all tools, unless overridden when adding a tool.
6568
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
6669
Applies to all tools, unless overridden when adding a tool.
70+
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
71+
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
72+
sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
73+
Applies to all tools, unless overridden when adding a tool.
74+
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
75+
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
76+
Applies to all tools, unless overridden when adding a tool.
77+
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
78+
Applies to all tools, unless overridden when adding a tool, which will be merged with the toolset's metadata.
79+
id: An optional unique ID for the toolset. A toolset needs to have an ID in order to be used in a durable execution environment like Temporal,
80+
in which case the ID will be used to identify the toolset's activities within the workflow.
6781
"""
6882
self.max_retries = max_retries
6983
self._id = id
7084
self.docstring_format = docstring_format
7185
self.require_parameter_descriptions = require_parameter_descriptions
7286
self.schema_generator = schema_generator
87+
self.strict = strict
88+
self.sequential = sequential
89+
self.requires_approval = requires_approval
90+
self.metadata = metadata
7391

7492
self.tools = {}
7593
for tool in tools:
@@ -97,8 +115,8 @@ def tool(
97115
require_parameter_descriptions: bool | None = None,
98116
schema_generator: type[GenerateJsonSchema] | None = None,
99117
strict: bool | None = None,
100-
sequential: bool = False,
101-
requires_approval: bool = False,
118+
sequential: bool | None = None,
119+
requires_approval: bool | None = None,
102120
metadata: dict[str, Any] | None = None,
103121
) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ...
104122

@@ -114,8 +132,8 @@ def tool(
114132
require_parameter_descriptions: bool | None = None,
115133
schema_generator: type[GenerateJsonSchema] | None = None,
116134
strict: bool | None = None,
117-
sequential: bool = False,
118-
requires_approval: bool = False,
135+
sequential: bool | None = None,
136+
requires_approval: bool | None = None,
119137
metadata: dict[str, Any] | None = None,
120138
) -> Any:
121139
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
@@ -165,10 +183,14 @@ async def spam(ctx: RunContext[str], y: float) -> float:
165183
If `None`, the default value is determined by the toolset.
166184
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
167185
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
186+
If `None`, the default value is determined by the toolset.
168187
sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
188+
If `None`, the default value is determined by the toolset.
169189
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
170190
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
191+
If `None`, the default value is determined by the toolset.
171192
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
193+
If `None`, the default value is determined by the toolset. If provided, it will be merged with the toolset's metadata.
172194
"""
173195

174196
def tool_decorator(
@@ -204,8 +226,8 @@ def add_function(
204226
require_parameter_descriptions: bool | None = None,
205227
schema_generator: type[GenerateJsonSchema] | None = None,
206228
strict: bool | None = None,
207-
sequential: bool = False,
208-
requires_approval: bool = False,
229+
sequential: bool | None = None,
230+
requires_approval: bool | None = None,
209231
metadata: dict[str, Any] | None = None,
210232
) -> None:
211233
"""Add a function as a tool to the toolset.
@@ -232,17 +254,27 @@ def add_function(
232254
If `None`, the default value is determined by the toolset.
233255
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
234256
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
257+
If `None`, the default value is determined by the toolset.
235258
sequential: Whether the function requires a sequential/serial execution environment. Defaults to False.
259+
If `None`, the default value is determined by the toolset.
236260
requires_approval: Whether this tool requires human-in-the-loop approval. Defaults to False.
237261
See the [tools documentation](../deferred-tools.md#human-in-the-loop-tool-approval) for more info.
262+
If `None`, the default value is determined by the toolset.
238263
metadata: Optional metadata for the tool. This is not sent to the model but can be used for filtering and tool behavior customization.
264+
If `None`, the default value is determined by the toolset. If provided, it will be merged with the toolset's metadata.
239265
"""
240266
if docstring_format is None:
241267
docstring_format = self.docstring_format
242268
if require_parameter_descriptions is None:
243269
require_parameter_descriptions = self.require_parameter_descriptions
244270
if schema_generator is None:
245271
schema_generator = self.schema_generator
272+
if strict is None:
273+
strict = self.strict
274+
if sequential is None:
275+
sequential = self.sequential
276+
if requires_approval is None:
277+
requires_approval = self.requires_approval
246278

247279
tool = Tool[AgentDepsT](
248280
func,
@@ -270,6 +302,8 @@ def add_tool(self, tool: Tool[AgentDepsT]) -> None:
270302
raise UserError(f'Tool name conflicts with existing tool: {tool.name!r}')
271303
if tool.max_retries is None:
272304
tool.max_retries = self.max_retries
305+
if self.metadata is not None:
306+
tool.metadata = self.metadata | (tool.metadata or {})
273307
self.tools[tool.name] = tool
274308

275309
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:

tests/test_tools.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,19 +1919,26 @@ def plain_tool(z: int) -> int:
19191919
assert plain_tool_def.metadata == {'type': 'plain'}
19201920

19211921
# Test with FunctionToolset.tool decorator
1922-
toolset = FunctionToolset()
1922+
toolset = FunctionToolset(metadata={'foo': 'bar'})
1923+
1924+
@toolset.tool
1925+
def toolset_plain_tool(a: str) -> str:
1926+
return a.upper() # pragma: no cover
1927+
1928+
toolset_plain_tool_def = toolset.tools['toolset_plain_tool']
1929+
assert toolset_plain_tool_def.metadata == {'foo': 'bar'}
19231930

19241931
@toolset.tool(metadata={'toolset': 'function'})
19251932
def toolset_tool(ctx: RunContext[None], a: str) -> str:
19261933
return a.upper() # pragma: no cover
19271934

19281935
toolset_tool_def = toolset.tools['toolset_tool']
1929-
assert toolset_tool_def.metadata == {'toolset': 'function'}
1936+
assert toolset_tool_def.metadata == {'foo': 'bar', 'toolset': 'function'}
19301937

19311938
# Test with FunctionToolset.add_function
19321939
def standalone_func(ctx: RunContext[None], b: float) -> float:
19331940
return b / 2 # pragma: no cover
19341941

19351942
toolset.add_function(standalone_func, metadata={'method': 'add_function'})
19361943
standalone_tool_def = toolset.tools['standalone_func']
1937-
assert standalone_tool_def.metadata == {'method': 'add_function'}
1944+
assert standalone_tool_def.metadata == {'foo': 'bar', 'method': 'add_function'}

0 commit comments

Comments
 (0)