From 6b8080032adb2dbfaafba7b4d3379e639148c8a5 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 2 Jan 2025 18:35:40 +0000 Subject: [PATCH 1/2] add "default" to "AgentDeps" --- pydantic_ai_slim/pydantic_ai/agent.py | 14 +++++++------- pydantic_ai_slim/pydantic_ai/tools.py | 28 +++++++++++++++++---------- tests/typed_agent.py | 9 +++++++-- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index b4af5db766..d304094cb7 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -125,7 +125,7 @@ def __init__( result_tool_name: str = 'final_result', result_tool_description: str | None = None, result_retries: int | None = None, - tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (), + tools: Sequence[Tool[AgentDeps] | ToolFuncEither[..., AgentDeps]] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', ): @@ -624,7 +624,7 @@ async def result_validator_deps(ctx: RunContext[str], data: str) -> str: return func @overload - def tool(self, func: ToolFuncContext[AgentDeps, ToolParams], /) -> ToolFuncContext[AgentDeps, ToolParams]: ... + def tool(self, func: ToolFuncContext[ToolParams, AgentDeps], /) -> ToolFuncContext[ToolParams, AgentDeps]: ... @overload def tool( @@ -633,11 +633,11 @@ def tool( *, retries: int | None = None, prepare: ToolPrepareFunc[AgentDeps] | None = None, - ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ... + ) -> Callable[[ToolFuncContext[ToolParams, AgentDeps]], ToolFuncContext[ToolParams, AgentDeps]]: ... def tool( self, - func: ToolFuncContext[AgentDeps, ToolParams] | None = None, + func: ToolFuncContext[ToolParams, AgentDeps] | None = None, /, *, retries: int | None = None, @@ -683,8 +683,8 @@ async def spam(ctx: RunContext[str], y: float) -> float: if func is None: def tool_decorator( - func_: ToolFuncContext[AgentDeps, ToolParams], - ) -> ToolFuncContext[AgentDeps, ToolParams]: + func_: ToolFuncContext[ToolParams, AgentDeps], + ) -> ToolFuncContext[ToolParams, AgentDeps]: # noinspection PyTypeChecker self._register_function(func_, True, retries, prepare) return func_ @@ -766,7 +766,7 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams def _register_function( self, - func: ToolFuncEither[AgentDeps, ToolParams], + func: ToolFuncEither[ToolParams, AgentDeps], takes_ctx: bool, retries: int | None, prepare: ToolPrepareFunc[AgentDeps] | None, diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 746a3ebc69..54b6a29175 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -4,11 +4,11 @@ import inspect from collections.abc import Awaitable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast from pydantic import ValidationError from pydantic_core import SchemaValidator -from typing_extensions import Concatenate, ParamSpec, TypeAlias +from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeAliasType, TypeVar from . import _pydantic, _utils, messages as _messages, models from .exceptions import ModelRetry, UnexpectedModelBehavior @@ -30,7 +30,7 @@ 'ToolDefinition', ) -AgentDeps = TypeVar('AgentDeps') +AgentDeps = TypeVar('AgentDeps', default=None) """Type variable for agent dependencies.""" @@ -81,23 +81,31 @@ def replace_with( Usage `SystemPromptFunc[AgentDeps]`. """ -ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any] +ToolFuncContext = TypeAliasType( + 'ToolFuncContext', + Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any], + type_params=(ToolParams, AgentDeps), +) """A tool function that takes `RunContext` as the first argument. -Usage `ToolContextFunc[AgentDeps, ToolParams]`. +Usage `ToolFuncContext[ToolParams, AgentDeps]`. """ ToolFuncPlain = Callable[ToolParams, Any] """A tool function that does not take `RunContext` as the first argument. Usage `ToolPlainFunc[ToolParams]`. """ -ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]] -"""Either part_kind of tool function. +ToolFuncEither = TypeAliasType( + 'ToolFuncEither', + Union[ToolFuncContext[ToolParams, AgentDeps], ToolFuncPlain[ToolParams]], + type_params=(ToolParams, AgentDeps), +) +"""Either kind of tool function. This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain]. -Usage `ToolFuncEither[AgentDeps, ToolParams]`. +Usage `ToolFuncEither[ToolParams, AgentDeps]`. """ ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDeps], ToolDefinition], Awaitable[ToolDefinition | None]]' """Definition of a function that can prepare a tool definition at call time. @@ -134,7 +142,7 @@ def hitchhiker(ctx: RunContext[int], answer: str) -> str: class Tool(Generic[AgentDeps]): """A tool function for an agent.""" - function: ToolFuncEither[AgentDeps, ...] + function: ToolFuncEither[..., AgentDeps] takes_ctx: bool max_retries: int | None name: str @@ -150,7 +158,7 @@ class Tool(Generic[AgentDeps]): def __init__( self, - function: ToolFuncEither[AgentDeps, ...], + function: ToolFuncEither[..., AgentDeps], *, takes_ctx: bool | None = None, max_retries: int | None = None, diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 3871037716..040485af92 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -196,7 +196,8 @@ def foobar_plain(x: str, y: int) -> str: Tool(foobar_ctx, takes_ctx=True) Tool(foobar_ctx) Tool(foobar_plain, takes_ctx=False) -Tool(foobar_plain) +assert_type(Tool(foobar_plain), Tool[None]) +assert_type(Tool(foobar_plain), Tool) # unfortunately we can't type check these cases, since from a typing perspect `foobar_ctx` is valid as a plain tool Tool(foobar_ctx, takes_ctx=False) @@ -206,12 +207,15 @@ def foobar_plain(x: str, y: int) -> str: Agent('test', tools=[foobar_plain], deps_type=int) Agent('test', tools=[foobar_plain]) Agent('test', tools=[Tool(foobar_ctx)], deps_type=int) -Agent('test', tools=[Tool(foobar_plain)], deps_type=int) Agent('test', tools=[Tool(foobar_ctx), foobar_ctx, foobar_plain], deps_type=int) +Agent('test', tools=[Tool(foobar_ctx), foobar_ctx, Tool(foobar_plain)], deps_type=int) Agent('test', tools=[foobar_ctx], deps_type=str) # pyright: ignore[reportArgumentType] +Agent('test', tools=[Tool(foobar_ctx), Tool(foobar_plain)], deps_type=str) # pyright: ignore[reportArgumentType] Agent('test', tools=[foobar_ctx]) # pyright: ignore[reportArgumentType] Agent('test', tools=[Tool(foobar_ctx)]) # pyright: ignore[reportArgumentType] +# since deps are not set, they default to `None`, so can't be `int` +Agent('test', tools=[Tool(foobar_plain)], deps_type=int) # pyright: ignore[reportArgumentType] # prepare example from docs: @@ -238,6 +242,7 @@ async def prepare_greet(ctx: RunContext[str], tool_def: ToolDefinition) -> ToolD default_agent = Agent() assert_type(default_agent, Agent[None, str]) assert_type(default_agent, Agent[None]) + assert_type(default_agent, Agent) partial_agent: Agent[MyDeps] = Agent(deps_type=MyDeps) assert_type(partial_agent, Agent[MyDeps, str]) From d36bca0b7ef211013bb5445142d3bdfa5efefc52 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 2 Jan 2025 11:54:42 -0700 Subject: [PATCH 2/2] Restore original order and add default to ToolParams --- pydantic_ai_slim/pydantic_ai/agent.py | 14 +++++++------- pydantic_ai_slim/pydantic_ai/tools.py | 24 ++++++++---------------- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index d304094cb7..b4af5db766 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -125,7 +125,7 @@ def __init__( result_tool_name: str = 'final_result', result_tool_description: str | None = None, result_retries: int | None = None, - tools: Sequence[Tool[AgentDeps] | ToolFuncEither[..., AgentDeps]] = (), + tools: Sequence[Tool[AgentDeps] | ToolFuncEither[AgentDeps, ...]] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', ): @@ -624,7 +624,7 @@ async def result_validator_deps(ctx: RunContext[str], data: str) -> str: return func @overload - def tool(self, func: ToolFuncContext[ToolParams, AgentDeps], /) -> ToolFuncContext[ToolParams, AgentDeps]: ... + def tool(self, func: ToolFuncContext[AgentDeps, ToolParams], /) -> ToolFuncContext[AgentDeps, ToolParams]: ... @overload def tool( @@ -633,11 +633,11 @@ def tool( *, retries: int | None = None, prepare: ToolPrepareFunc[AgentDeps] | None = None, - ) -> Callable[[ToolFuncContext[ToolParams, AgentDeps]], ToolFuncContext[ToolParams, AgentDeps]]: ... + ) -> Callable[[ToolFuncContext[AgentDeps, ToolParams]], ToolFuncContext[AgentDeps, ToolParams]]: ... def tool( self, - func: ToolFuncContext[ToolParams, AgentDeps] | None = None, + func: ToolFuncContext[AgentDeps, ToolParams] | None = None, /, *, retries: int | None = None, @@ -683,8 +683,8 @@ async def spam(ctx: RunContext[str], y: float) -> float: if func is None: def tool_decorator( - func_: ToolFuncContext[ToolParams, AgentDeps], - ) -> ToolFuncContext[ToolParams, AgentDeps]: + func_: ToolFuncContext[AgentDeps, ToolParams], + ) -> ToolFuncContext[AgentDeps, ToolParams]: # noinspection PyTypeChecker self._register_function(func_, True, retries, prepare) return func_ @@ -766,7 +766,7 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams def _register_function( self, - func: ToolFuncEither[ToolParams, AgentDeps], + func: ToolFuncEither[AgentDeps, ToolParams], takes_ctx: bool, retries: int | None, prepare: ToolPrepareFunc[AgentDeps] | None, diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 54b6a29175..54f7a9affa 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -8,7 +8,7 @@ from pydantic import ValidationError from pydantic_core import SchemaValidator -from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeAliasType, TypeVar +from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar from . import _pydantic, _utils, messages as _messages, models from .exceptions import ModelRetry, UnexpectedModelBehavior @@ -67,7 +67,7 @@ def replace_with( return dataclasses.replace(self, **kwargs) -ToolParams = ParamSpec('ToolParams') +ToolParams = ParamSpec('ToolParams', default=...) """Retrieval function param spec.""" SystemPromptFunc = Union[ @@ -81,31 +81,23 @@ def replace_with( Usage `SystemPromptFunc[AgentDeps]`. """ -ToolFuncContext = TypeAliasType( - 'ToolFuncContext', - Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any], - type_params=(ToolParams, AgentDeps), -) +ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any] """A tool function that takes `RunContext` as the first argument. -Usage `ToolFuncContext[ToolParams, AgentDeps]`. +Usage `ToolContextFunc[AgentDeps, ToolParams]`. """ ToolFuncPlain = Callable[ToolParams, Any] """A tool function that does not take `RunContext` as the first argument. Usage `ToolPlainFunc[ToolParams]`. """ -ToolFuncEither = TypeAliasType( - 'ToolFuncEither', - Union[ToolFuncContext[ToolParams, AgentDeps], ToolFuncPlain[ToolParams]], - type_params=(ToolParams, AgentDeps), -) +ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]] """Either kind of tool function. This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain]. -Usage `ToolFuncEither[ToolParams, AgentDeps]`. +Usage `ToolFuncEither[AgentDeps, ToolParams]`. """ ToolPrepareFunc: TypeAlias = 'Callable[[RunContext[AgentDeps], ToolDefinition], Awaitable[ToolDefinition | None]]' """Definition of a function that can prepare a tool definition at call time. @@ -142,7 +134,7 @@ def hitchhiker(ctx: RunContext[int], answer: str) -> str: class Tool(Generic[AgentDeps]): """A tool function for an agent.""" - function: ToolFuncEither[..., AgentDeps] + function: ToolFuncEither[AgentDeps] takes_ctx: bool max_retries: int | None name: str @@ -158,7 +150,7 @@ class Tool(Generic[AgentDeps]): def __init__( self, - function: ToolFuncEither[..., AgentDeps], + function: ToolFuncEither[AgentDeps], *, takes_ctx: bool | None = None, max_retries: int | None = None,