Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 2 additions & 20 deletions pydantic_ai_slim/pydantic_ai/_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from inspect import Parameter, signature
from typing import TYPE_CHECKING, Any, Callable, TypedDict, cast, get_origin

from pydantic import ConfigDict, TypeAdapter
from pydantic import ConfigDict
from pydantic._internal import _decorators, _generate_schema, _typing_extra
from pydantic._internal._config import ConfigWrapper
from pydantic.fields import FieldInfo
Expand All @@ -23,7 +23,7 @@
from .tools import ObjectJsonSchema


__all__ = 'function_schema', 'LazyTypeAdapter'
__all__ = ('function_schema',)


class FunctionSchema(TypedDict):
Expand Down Expand Up @@ -214,21 +214,3 @@ def _is_call_ctx(annotation: Any) -> bool:
return annotation is RunContext or (
_typing_extra.is_generic_alias(annotation) and get_origin(annotation) is RunContext
)


if TYPE_CHECKING:
LazyTypeAdapter = TypeAdapter
else:

class LazyTypeAdapter:
__slots__ = '_args', '_kwargs', '_type_adapter'

def __init__(self, *args, **kwargs):
self._args = args
self._kwargs = kwargs
self._type_adapter = None

def __getattr__(self, item):
if self._type_adapter is None:
self._type_adapter = TypeAdapter(*self._args, **self._kwargs)
return getattr(self._type_adapter, item)
9 changes: 5 additions & 4 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pydantic
import pydantic_core

from . import _pydantic
from ._utils import now_utc as _now_utc


Expand Down Expand Up @@ -48,7 +47,7 @@ class UserPrompt:
"""Message type identifier, this type is available on all messages as a discriminator."""


tool_return_ta: pydantic.TypeAdapter[Any] = _pydantic.LazyTypeAdapter(Any)
tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True))


@dataclass
Expand Down Expand Up @@ -84,7 +83,7 @@ def model_response_object(self) -> dict[str, Any]:
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}


ErrorDetailsTa = _pydantic.LazyTypeAdapter(list[pydantic_core.ErrorDetails])
ErrorDetailsTa = pydantic.TypeAdapter(list[pydantic_core.ErrorDetails], config=pydantic.ConfigDict(defer_build=True))


@dataclass
Expand Down Expand Up @@ -220,5 +219,7 @@ def from_text(content: str, timestamp: datetime | None = None) -> ModelResponse:
Message = Union[SystemPrompt, UserPrompt, ToolReturn, RetryPrompt, ModelResponse]
"""Any message send to or returned by a model."""

MessagesTypeAdapter = _pydantic.LazyTypeAdapter(list[Annotated[Message, pydantic.Discriminator('message_kind')]])
MessagesTypeAdapter = pydantic.TypeAdapter(
list[Annotated[Message, pydantic.Discriminator('message_kind')]], config=pydantic.ConfigDict(defer_build=True)
)
"""Pydantic [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] for (de)serializing messages."""
48 changes: 25 additions & 23 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from datetime import datetime
from typing import Annotated, Any, Literal, Protocol, Union

import pydantic
import pydantic_core
from httpx import USE_CLIENT_DEFAULT, AsyncClient as AsyncHTTPClient, Response as HTTPResponse
from pydantic import Discriminator, Field, Tag
from typing_extensions import NotRequired, TypedDict, TypeGuard, assert_never

from .. import UnexpectedModelBehavior, _pydantic, _utils, exceptions, result
from .. import UnexpectedModelBehavior, _utils, exceptions, result
from ..messages import (
ArgsDict,
Message,
Expand Down Expand Up @@ -386,6 +386,7 @@ def timestamp(self) -> datetime:
# TypeAdapters take care of validation and serialization


@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiRequest(TypedDict):
"""Schema for an API request to the Gemini API.

Expand Down Expand Up @@ -457,7 +458,7 @@ class _GeminiTextPart(TypedDict):


class _GeminiFunctionCallPart(TypedDict):
function_call: Annotated[_GeminiFunctionCall, Field(alias='functionCall')]
function_call: Annotated[_GeminiFunctionCall, pydantic.Field(alias='functionCall')]


def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
Expand Down Expand Up @@ -487,7 +488,7 @@ class _GeminiFunctionCall(TypedDict):


class _GeminiFunctionResponsePart(TypedDict):
function_response: Annotated[_GeminiFunctionResponse, Field(alias='functionResponse')]
function_response: Annotated[_GeminiFunctionResponse, pydantic.Field(alias='functionResponse')]


def _response_part_from_response(name: str, response: dict[str, Any]) -> _GeminiFunctionResponsePart:
Expand Down Expand Up @@ -517,11 +518,11 @@ def _part_discriminator(v: Any) -> str:
# TODO discriminator
_GeminiPartUnion = Annotated[
Union[
Annotated[_GeminiTextPart, Tag('text')],
Annotated[_GeminiFunctionCallPart, Tag('function_call')],
Annotated[_GeminiFunctionResponsePart, Tag('function_response')],
Annotated[_GeminiTextPart, pydantic.Tag('text')],
Annotated[_GeminiFunctionCallPart, pydantic.Tag('function_call')],
Annotated[_GeminiFunctionResponsePart, pydantic.Tag('function_response')],
],
Discriminator(_part_discriminator),
pydantic.Discriminator(_part_discriminator),
]


Expand All @@ -531,7 +532,7 @@ class _GeminiTextContent(TypedDict):


class _GeminiTools(TypedDict):
function_declarations: list[Annotated[_GeminiFunction, Field(alias='functionDeclarations')]]
function_declarations: list[Annotated[_GeminiFunction, pydantic.Field(alias='functionDeclarations')]]


class _GeminiFunction(TypedDict):
Expand Down Expand Up @@ -572,6 +573,7 @@ class _GeminiFunctionCallingConfig(TypedDict):
allowed_function_names: list[str]


@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiResponse(TypedDict):
"""Schema for the response from the Gemini API.

Expand All @@ -581,8 +583,8 @@ class _GeminiResponse(TypedDict):

candidates: list[_GeminiCandidates]
# usageMetadata appears to be required by both APIs but is omitted when streaming responses until the last response
usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, Field(alias='usageMetadata')]]
prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, Field(alias='promptFeedback')]]
usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]


# TODO: Delete the next three functions once we've reworked streams to be more flexible
Expand Down Expand Up @@ -618,14 +620,14 @@ class _GeminiCandidates(TypedDict):
"""See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

content: _GeminiContent
finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], Field(alias='finishReason')]]
finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS'], pydantic.Field(alias='finishReason')]]
"""
See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible,
but let's wait until we see them and know what they mean to add them here.
"""
avg_log_probs: NotRequired[Annotated[float, Field(alias='avgLogProbs')]]
avg_log_probs: NotRequired[Annotated[float, pydantic.Field(alias='avgLogProbs')]]
index: NotRequired[int]
safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]]
safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]


class _GeminiUsageMetaData(TypedDict, total=False):
Expand All @@ -634,10 +636,10 @@ class _GeminiUsageMetaData(TypedDict, total=False):
The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
"""

prompt_token_count: Annotated[int, Field(alias='promptTokenCount')]
candidates_token_count: NotRequired[Annotated[int, Field(alias='candidatesTokenCount')]]
total_token_count: Annotated[int, Field(alias='totalTokenCount')]
cached_content_token_count: NotRequired[Annotated[int, Field(alias='cachedContentTokenCount')]]
prompt_token_count: Annotated[int, pydantic.Field(alias='promptTokenCount')]
candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]


def _metadata_as_cost(response: _GeminiResponse) -> result.Cost:
Expand Down Expand Up @@ -671,15 +673,15 @@ class _GeminiSafetyRating(TypedDict):
class _GeminiPromptFeedback(TypedDict):
"""See <https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse>."""

block_reason: Annotated[str, Field(alias='blockReason')]
safety_ratings: Annotated[list[_GeminiSafetyRating], Field(alias='safetyRatings')]
block_reason: Annotated[str, pydantic.Field(alias='blockReason')]
safety_ratings: Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]


_gemini_request_ta = _pydantic.LazyTypeAdapter(_GeminiRequest)
_gemini_response_ta = _pydantic.LazyTypeAdapter(_GeminiResponse)
_gemini_request_ta = pydantic.TypeAdapter(_GeminiRequest)
_gemini_response_ta = pydantic.TypeAdapter(_GeminiResponse)

# steam requests return a list of https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
_gemini_streamed_response_ta = _pydantic.LazyTypeAdapter(list[_GeminiResponse])
_gemini_streamed_response_ta = pydantic.TypeAdapter(list[_GeminiResponse], config=pydantic.ConfigDict(defer_build=True))


class _GeminiJsonSchema:
Expand Down
Loading