Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast

from pydantic import ValidationError
from pydantic_core import SchemaValidator
from typing_extensions import Concatenate, ParamSpec, TypeAlias
from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar

from . import _pydantic, _utils, messages as _messages, models
from .exceptions import ModelRetry, UnexpectedModelBehavior
Expand All @@ -30,7 +30,7 @@
'ToolDefinition',
)

AgentDeps = TypeVar('AgentDeps')
AgentDeps = TypeVar('AgentDeps', default=None)
"""Type variable for agent dependencies."""


Expand Down Expand Up @@ -67,7 +67,7 @@ def replace_with(
return dataclasses.replace(self, **kwargs)


ToolParams = ParamSpec('ToolParams')
ToolParams = ParamSpec('ToolParams', default=...)
"""Retrieval function param spec."""

SystemPromptFunc = Union[
Expand All @@ -92,7 +92,7 @@ def replace_with(
Usage `ToolPlainFunc[ToolParams]`.
"""
ToolFuncEither = Union[ToolFuncContext[AgentDeps, ToolParams], ToolFuncPlain[ToolParams]]
"""Either part_kind of tool function.
"""Either kind of tool function.

This is just a union of [`ToolFuncContext`][pydantic_ai.tools.ToolFuncContext] and
[`ToolFuncPlain`][pydantic_ai.tools.ToolFuncPlain].
Expand Down Expand Up @@ -134,7 +134,7 @@ def hitchhiker(ctx: RunContext[int], answer: str) -> str:
class Tool(Generic[AgentDeps]):
"""A tool function for an agent."""

function: ToolFuncEither[AgentDeps, ...]
function: ToolFuncEither[AgentDeps]
takes_ctx: bool
max_retries: int | None
name: str
Expand All @@ -150,7 +150,7 @@ class Tool(Generic[AgentDeps]):

def __init__(
self,
function: ToolFuncEither[AgentDeps, ...],
function: ToolFuncEither[AgentDeps],
*,
takes_ctx: bool | None = None,
max_retries: int | None = None,
Expand Down
9 changes: 7 additions & 2 deletions tests/typed_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ def foobar_plain(x: str, y: int) -> str:
Tool(foobar_ctx, takes_ctx=True)
Tool(foobar_ctx)
Tool(foobar_plain, takes_ctx=False)
Tool(foobar_plain)
assert_type(Tool(foobar_plain), Tool[None])
assert_type(Tool(foobar_plain), Tool)

# unfortunately we can't type check these cases, since from a typing perspect `foobar_ctx` is valid as a plain tool
Tool(foobar_ctx, takes_ctx=False)
Expand All @@ -206,12 +207,15 @@ def foobar_plain(x: str, y: int) -> str:
Agent('test', tools=[foobar_plain], deps_type=int)
Agent('test', tools=[foobar_plain])
Agent('test', tools=[Tool(foobar_ctx)], deps_type=int)
Agent('test', tools=[Tool(foobar_plain)], deps_type=int)
Agent('test', tools=[Tool(foobar_ctx), foobar_ctx, foobar_plain], deps_type=int)
Agent('test', tools=[Tool(foobar_ctx), foobar_ctx, Tool(foobar_plain)], deps_type=int)

Agent('test', tools=[foobar_ctx], deps_type=str) # pyright: ignore[reportArgumentType]
Agent('test', tools=[Tool(foobar_ctx), Tool(foobar_plain)], deps_type=str) # pyright: ignore[reportArgumentType]
Agent('test', tools=[foobar_ctx]) # pyright: ignore[reportArgumentType]
Agent('test', tools=[Tool(foobar_ctx)]) # pyright: ignore[reportArgumentType]
# since deps are not set, they default to `None`, so can't be `int`
Agent('test', tools=[Tool(foobar_plain)], deps_type=int) # pyright: ignore[reportArgumentType]

# prepare example from docs:

Expand All @@ -238,6 +242,7 @@ async def prepare_greet(ctx: RunContext[str], tool_def: ToolDefinition) -> ToolD
default_agent = Agent()
assert_type(default_agent, Agent[None, str])
assert_type(default_agent, Agent[None])
assert_type(default_agent, Agent)

partial_agent: Agent[MyDeps] = Agent(deps_type=MyDeps)
assert_type(partial_agent, Agent[MyDeps, str])
Expand Down
Loading