Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@

from typing_extensions import assert_never

from pydantic_ai.providers import Provider

from .. import UnexpectedModelBehavior, _utils, usage
from ..exceptions import UserError
from ..messages import (
BinaryContent,
FileUrl,
Expand All @@ -30,6 +29,7 @@
VideoUrl,
)
from ..profiles import ModelProfileSpec
from ..providers import Provider
from ..settings import ModelSettings
from ..tools import ToolDefinition
from . import (
Expand All @@ -52,6 +52,7 @@
FunctionDeclarationDict,
GenerateContentConfigDict,
GenerateContentResponse,
HttpOptionsDict,
Part,
PartDict,
SafetySettingDict,
Expand Down Expand Up @@ -252,8 +253,17 @@ async def _generate_content(
tool_config = self._get_tool_config(model_request_parameters, tools)
system_instruction, contents = await self._map_messages(messages)

http_options: HttpOptionsDict = {
'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
}
if timeout := model_settings.get('timeout'):
if isinstance(timeout, (int, float)):
http_options['timeout'] = int(1000 * timeout)
else:
raise UserError('Google does not support setting ModelSettings.timeout to a httpx.Timeout')

config = GenerateContentConfigDict(
http_options={'headers': {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}},
http_options=http_options,
system_instruction=system_instruction,
temperature=model_settings.get('temperature'),
top_p=model_settings.get('top_p'),
Expand Down
64 changes: 64 additions & 0 deletions tests/models/cassettes/test_google/test_google_timeout.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
interactions:
- request:
headers:
accept:
- '*/*'
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '87'
content-type:
- application/json
host:
- generativelanguage.googleapis.com
method: POST
parsed_body:
contents:
- parts:
- text: Hello!
role: user
generationConfig: {}
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent
response:
headers:
alt-svc:
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
content-length:
- '687'
content-type:
- application/json; charset=UTF-8
server-timing:
- gfet4t7; dur=304
transfer-encoding:
- chunked
vary:
- Origin
- X-Origin
- Referer
parsed_body:
candidates:
- avgLogprobs: -0.0009780019860376012
content:
parts:
- text: |
Hello there! How can I help you today?
role: model
finishReason: STOP
modelVersion: gemini-1.5-flash
responseId: KKRRaJyZNJiFm9IPzMfNiAk
usageMetadata:
candidatesTokenCount: 11
candidatesTokensDetails:
- modality: TEXT
tokenCount: 11
promptTokenCount: 2
promptTokensDetails:
- modality: TEXT
tokenCount: 2
totalTokenCount: 13
status:
code: 200
message: OK
version: 1
13 changes: 12 additions & 1 deletion tests/models/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Union

import pytest
from httpx import Request
from httpx import Request, Timeout
from inline_snapshot import Is, snapshot
from pytest_mock import MockerFixture
from typing_extensions import TypedDict
Expand Down Expand Up @@ -806,3 +806,14 @@ async def bar() -> str:
),
]
)


async def test_google_timeout(allow_model_requests: None, google_provider: GoogleProvider):
model = GoogleModel('gemini-1.5-flash', provider=google_provider)
agent = Agent(model=model)

result = await agent.run('Hello!', model_settings={'timeout': 10})
assert result.output == snapshot('Hello there! How can I help you today?\n')

with pytest.raises(UserError, match='Google does not support setting ModelSettings.timeout to a httpx.Timeout'):
await agent.run('Hello!', model_settings={'timeout': Timeout(10)})