6 changes: 5 additions & 1 deletion docs/models.md
@@ -184,10 +184,14 @@ You can then use [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] by name:
```python {title="gemini_model_by_name.py"}
from pydantic_ai import Agent

agent = Agent('gemini-1.5-flash')
agent = Agent('google-gla:gemini-1.5-flash')
...
```

!!! note
The `google-gla` provider prefix represents the [Google **G**enerative **L**anguage **A**PI](https://ai.google.dev/api/all-methods) for `GeminiModel`s.
The `google-vertex` prefix is used with [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models) for `VertexAIModel`s.

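For contrast, here is a minimal sketch (not part of this diff) using the `google-vertex` prefix instead; the filename is illustrative and it assumes Vertex AI credentials (e.g. application default credentials) are already configured:

```python {title="vertexai_model_by_name.py"}
from pydantic_ai import Agent

# Resolves to a VertexAIModel via the google-vertex prefix;
# authentication happens later, when the agent actually makes a request.
agent = Agent('google-vertex:gemini-1.5-flash')
...
```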
Or initialise the model directly with just the model name:

```python {title="gemini_model_init.py"}
2 changes: 1 addition & 1 deletion examples/pydantic_ai_examples/sql_gen.py
@@ -93,7 +93,7 @@ class InvalidRequest(BaseModel):

Response: TypeAlias = Union[Success, InvalidRequest]
agent: Agent[Deps, Response] = Agent(
'gemini-1.5-flash',
'google-gla:gemini-1.5-flash',
# Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
result_type=Response, # type: ignore
deps_type=Deps,
2 changes: 1 addition & 1 deletion examples/pydantic_ai_examples/stream_markdown.py
@@ -25,7 +25,7 @@

# models to try, and the appropriate env var
models: list[tuple[KnownModelName, str]] = [
('gemini-1.5-flash', 'GEMINI_API_KEY'),
('google-gla:gemini-1.5-flash', 'GEMINI_API_KEY'),
('openai:gpt-4o-mini', 'OPENAI_API_KEY'),
('groq:llama-3.1-70b-versatile', 'GROQ_API_KEY'),
]
38 changes: 28 additions & 10 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -48,13 +48,12 @@
'groq:mixtral-8x7b-32768',
'groq:gemma2-9b-it',
'groq:gemma-7b-it',
'gemini-1.5-flash',
'gemini-1.5-pro',
'gemini-2.0-flash-exp',
'vertexai:gemini-1.5-flash',
'vertexai:gemini-1.5-pro',
# since mistral models are supported by other providers (e.g. ollama), and some of their models (e.g. "codestral")
# don't start with "mistral", we add the "mistral:" prefix to all to be explicit
'google-gla:gemini-1.5-flash',
'google-gla:gemini-1.5-pro',
'google-gla:gemini-2.0-flash-exp',
'google-vertex:gemini-1.5-flash',
'google-vertex:gemini-1.5-pro',
'google-vertex:gemini-2.0-flash-exp',
'mistral:mistral-small-latest',
'mistral:mistral-large-latest',
'mistral:codestral-latest',
@@ -76,9 +75,9 @@
'ollama:qwen2',
'ollama:qwen2.5',
'ollama:starcoder2',
'claude-3-5-haiku-latest',
'claude-3-5-sonnet-latest',
'claude-3-opus-latest',
'anthropic:claude-3-5-haiku-latest',
'anthropic:claude-3-5-sonnet-latest',
'anthropic:claude-3-opus-latest',
'test',
]
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -274,6 +273,15 @@ def infer_model(model: Model | KnownModelName) -> Model:
from .openai import OpenAIModel

return OpenAIModel(model[7:])
elif model.startswith(('gpt', 'o1')):
from .openai import OpenAIModel

return OpenAIModel(model)
elif model.startswith('google-gla'):
from .gemini import GeminiModel

return GeminiModel(model[11:]) # pyright: ignore[reportArgumentType]
# backwards compatibility with old model names (e.g., gemini-1.5-flash -> google-gla:gemini-1.5-flash)
elif model.startswith('gemini'):
from .gemini import GeminiModel

@@ -283,6 +291,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
from .groq import GroqModel

return GroqModel(model[5:]) # pyright: ignore[reportArgumentType]
elif model.startswith('google-vertex'):
from .vertexai import VertexAIModel

return VertexAIModel(model[14:]) # pyright: ignore[reportArgumentType]
# backwards compatibility with old model names (e.g., vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
elif model.startswith('vertexai:'):
from .vertexai import VertexAIModel

@@ -295,6 +308,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
from .ollama import OllamaModel

return OllamaModel(model[7:])
elif model.startswith('anthropic'):
from .anthropic import AnthropicModel

return AnthropicModel(model[10:])
# backwards compatibility with old model names (e.g., claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
elif model.startswith('claude'):
from .anthropic import AnthropicModel

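As a usage sketch of the inference logic above (not part of the diff itself): with the relevant API key set, the new prefixed names and the legacy unprefixed names resolve to the same model class, and `name()` now reports the canonical prefixed form. The dummy API key below is purely illustrative.

```python
import os

from pydantic_ai.models import infer_model

# Illustrative only: a dummy key so GeminiModel can be constructed without a real account.
os.environ['GEMINI_API_KEY'] = 'dummy-key'

new_style = infer_model('google-gla:gemini-1.5-flash')
legacy = infer_model('gemini-1.5-flash')  # backwards-compatible alias

assert type(new_style) is type(legacy)  # both are GeminiModel
assert new_style.name() == legacy.name() == 'google-gla:gemini-1.5-flash'
```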
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -136,7 +136,7 @@ async def agent_model(
)

def name(self) -> str:
return self.model_name
return f'anthropic:{self.model_name}'

@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ToolParam:
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -111,7 +111,7 @@ async def agent_model(
)

def name(self) -> str:
return self.model_name
return f'google-gla:{self.model_name}'


class AuthProtocol(Protocol):
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/vertexai.py
@@ -164,7 +164,7 @@ async def ainit(self) -> tuple[str, BearerTokenAuth]:
return url, auth

def name(self) -> str:
return f'vertexai:{self.model_name}'
return f'google-vertex:{self.model_name}'


# pyright: reportUnknownMemberType=false
2 changes: 1 addition & 1 deletion tests/models/test_anthropic.py
@@ -45,7 +45,7 @@
def test_init():
m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
assert m.client.api_key == 'foobar'
assert m.name() == 'claude-3-5-haiku-latest'
assert m.name() == 'anthropic:claude-3-5-haiku-latest'


@dataclass
93 changes: 67 additions & 26 deletions tests/models/test_model.py
@@ -1,43 +1,84 @@
from importlib import import_module

import pytest

from pydantic_ai import UserError
from pydantic_ai.models import infer_model
from pydantic_ai.models.gemini import GeminiModel

from ..conftest import TestEnv, try_import
from ..conftest import TestEnv

TEST_CASES = [
('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
('OPENAI_API_KEY', 'gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
('OPENAI_API_KEY', 'o1', 'openai:o1', 'openai', 'OpenAIModel'),
('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
('GEMINI_API_KEY', 'gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
(
'GEMINI_API_KEY',
'google-vertex:gemini-1.5-flash',
'google-vertex:gemini-1.5-flash',
'vertexai',
'VertexAIModel',
),
(
'GEMINI_API_KEY',
'vertexai:gemini-1.5-flash',
'google-vertex:gemini-1.5-flash',
'vertexai',
'VertexAIModel',
),
(
'ANTHROPIC_API_KEY',
'anthropic:claude-3-5-haiku-latest',
'anthropic:claude-3-5-haiku-latest',
'anthropic',
'AnthropicModel',
),
(
'ANTHROPIC_API_KEY',
'claude-3-5-haiku-latest',
'anthropic:claude-3-5-haiku-latest',
'anthropic',
'AnthropicModel',
),
(
'GROQ_API_KEY',
'groq:llama-3.3-70b-versatile',
'groq:llama-3.3-70b-versatile',
'groq',
'GroqModel',
),
('OLLAMA_API_KEY', 'ollama:llama3', 'ollama:llama3', 'ollama', 'OllamaModel'),
(
'MISTRAL_API_KEY',
'mistral:mistral-small-latest',
'mistral:mistral-small-latest',
'mistral',
'MistralModel',
),
]

with try_import() as openai_imports_successful:
from pydantic_ai.models.openai import OpenAIModel

with try_import() as vertexai_imports_successful:
from pydantic_ai.models.vertexai import VertexAIModel
@pytest.mark.parametrize('mock_api_key, model_name, expected_model_name, module_name, model_class_name', TEST_CASES)
def test_infer_model(
env: TestEnv, mock_api_key: str, model_name: str, expected_model_name: str, module_name: str, model_class_name: str
):
try:
model_module = import_module(f'pydantic_ai.models.{module_name}')
expected_model = getattr(model_module, model_class_name)
except ImportError:
pytest.skip(f'{model_name} dependencies not installed')

env.set(mock_api_key, 'via-env-var')

@pytest.mark.skipif(not openai_imports_successful(), reason='openai not installed')
def test_infer_str_openai(env: TestEnv):
env.set('OPENAI_API_KEY', 'via-env-var')
m = infer_model('openai:gpt-3.5-turbo')
assert isinstance(m, OpenAIModel)
assert m.name() == 'openai:gpt-3.5-turbo'
m = infer_model(model_name) # pyright: ignore[reportArgumentType]
assert isinstance(m, expected_model)
assert m.name() == expected_model_name

m2 = infer_model(m)
assert m2 is m


def test_infer_str_gemini(env: TestEnv):
env.set('GEMINI_API_KEY', 'via-env-var')
m = infer_model('gemini-1.5-flash')
assert isinstance(m, GeminiModel)
assert m.name() == 'gemini-1.5-flash'


@pytest.mark.skipif(not vertexai_imports_successful(), reason='google-auth not installed')
def test_infer_vertexai(env: TestEnv):
m = infer_model('vertexai:gemini-1.5-flash')
assert isinstance(m, VertexAIModel)
assert m.name() == 'vertexai:gemini-1.5-flash'


def test_infer_str_unknown():
with pytest.raises(UserError, match='Unknown model: foobar'):
infer_model('foobar') # pyright: ignore[reportArgumentType]
4 changes: 2 additions & 2 deletions tests/models/test_vertexai.py
@@ -40,7 +40,7 @@ async def test_init_service_account(tmp_path: Path, allow_model_requests: None):
'publishers/google/models/gemini-1.5-flash:'
)
assert model.auth is not None
assert model.name() == snapshot('vertexai:gemini-1.5-flash')
assert model.name() == snapshot('google-vertex:gemini-1.5-flash')


class NoOpCredentials:
@@ -67,7 +67,7 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
'publishers/google/models/gemini-1.5-flash:'
)
assert model.auth is not None
assert model.name() == snapshot('vertexai:gemini-1.5-flash')
assert model.name() == snapshot('google-vertex:gemini-1.5-flash')

await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
assert model.url is not None
4 changes: 2 additions & 2 deletions tests/test_agent.py
@@ -780,15 +780,15 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:

def test_model_requests_blocked(env: TestEnv, set_event_loop: None):
env.set('GEMINI_API_KEY', 'foobar')
agent = Agent('gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)
agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)

with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'):
agent.run_sync('Hello')


def test_override_model(env: TestEnv, set_event_loop: None):
env.set('GEMINI_API_KEY', 'foobar')
agent = Agent('gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)
agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)

with agent.override(model='test'):
result = agent.run_sync('Hello')