diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 746a3ebc69..54f7a9affa 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -4,11 +4,11 @@ import inspect from collections.abc import Awaitable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast from pydantic import ValidationError from pydantic_core import SchemaValidator -from typing_extensions import Concatenate, ParamSpec, TypeAlias +from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar from . import _pydantic, _utils, messages as _messages, models from .exceptions import ModelRetry, UnexpectedModelBehavior @@ -30,7 +30,7 @@ 'ToolDefinition', ) -AgentDeps = TypeVar('AgentDeps') +AgentDeps = TypeVar('AgentDeps', default=None) """Type variable for agent dependencies.""" @@ -67,7 +67,7 @@ def replace_with( return dataclasses.replace(self, **kwargs) -ToolParams = ParamSpec('ToolParams') +ToolParams = ParamSpec('ToolParams', default=...) """Retrieval function param spec.""" SystemPromptFunc = Union[ @@ -92,7 +92,7 @@ def replace_with( Usage `ToolPlainFunc[ToolParams]`. """ ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]] -"""Either part_kind of tool function. +"""Either kind of tool function. This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and [`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain]. @@ -134,7 +134,7 @@ def hitchhiker(ctx: RunContext[int], answer: str) -> str: class Tool(Generic[AgentDeps]): """A tool function for an agent.""" - function: ToolFuncEither[AgentDeps, ...] + function: ToolFuncEither[AgentDeps] takes_ctx: bool max_retries: int | None name: str @@ -150,7 +150,7 @@ class Tool(Generic[AgentDeps]): def __init__( self, - function: ToolFuncEither[AgentDeps, ...], + function: ToolFuncEither[AgentDeps], *, takes_ctx: bool | None = None, max_retries: int | None = None, diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 3871037716..040485af92 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -196,7 +196,8 @@ def foobar_plain(x: str, y: int) -> str: Tool(foobar_ctx, takes_ctx=True) Tool(foobar_ctx) Tool(foobar_plain, takes_ctx=False) -Tool(foobar_plain) +assert_type(Tool(foobar_plain), Tool[None]) +assert_type(Tool(foobar_plain), Tool) # unfortunately we can't type check these cases, since from a typing perspect `foobar_ctx` is valid as a plain tool Tool(foobar_ctx, takes_ctx=False) @@ -206,12 +207,15 @@ def foobar_plain(x: str, y: int) -> str: Agent('test', tools=[foobar_plain], deps_type=int) Agent('test', tools=[foobar_plain]) Agent('test', tools=[Tool(foobar_ctx)], deps_type=int) -Agent('test', tools=[Tool(foobar_plain)], deps_type=int) Agent('test', tools=[Tool(foobar_ctx), foobar_ctx, foobar_plain], deps_type=int) +Agent('test', tools=[Tool(foobar_ctx), foobar_ctx, Tool(foobar_plain)], deps_type=int) Agent('test', tools=[foobar_ctx], deps_type=str) # pyright: ignore[reportArgumentType] +Agent('test', tools=[Tool(foobar_ctx), Tool(foobar_plain)], deps_type=str) # pyright: ignore[reportArgumentType] Agent('test', tools=[foobar_ctx]) # pyright: ignore[reportArgumentType] Agent('test', tools=[Tool(foobar_ctx)]) # pyright: ignore[reportArgumentType] +# since deps are not set, they default to `None`, so can't be `int` +Agent('test', tools=[Tool(foobar_plain)], deps_type=int) # pyright: ignore[reportArgumentType] # prepare example from docs: @@ -238,6 +242,7 @@ async def prepare_greet(ctx: RunContext[str], tool_def: ToolDefinition) -> ToolD default_agent = Agent() assert_type(default_agent, Agent[None, str]) assert_type(default_agent, Agent[None]) + assert_type(default_agent, Agent) partial_agent: Agent[MyDeps] = Agent(deps_type=MyDeps) assert_type(partial_agent, Agent[MyDeps, str])