Skip to content

Commit 7920cb9

Browse files
Prefix all models with provider for consistency (#593)
Co-authored-by: Samuel Colvin <[email protected]>
1 parent 34ec8a6 commit 7920cb9

File tree

11 files changed

+110
-47
lines changed

11 files changed

+110
-47
lines changed

docs/models.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,10 +184,14 @@ You can then use [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] by name:
184184
```python {title="gemini_model_by_name.py"}
185185
from pydantic_ai import Agent
186186

187-
agent = Agent('gemini-1.5-flash')
187+
agent = Agent('google-gla:gemini-1.5-flash')
188188
...
189189
```
190190

191+
!!! note
192+
The `google-gla` provider prefix represents the [Google **G**enerative **L**anguage **A**PI](https://ai.google.dev/api/all-methods) for `GeminiModel`s.
193+
`google-vertex` is used with [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models) for `VertexAIModel`s.
194+
191195
Or initialise the model directly with just the model name:
192196

193197
```python {title="gemini_model_init.py"}

examples/pydantic_ai_examples/sql_gen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class InvalidRequest(BaseModel):
9393

9494
Response: TypeAlias = Union[Success, InvalidRequest]
9595
agent: Agent[Deps, Response] = Agent(
96-
'gemini-1.5-flash',
96+
'google-gla:gemini-1.5-flash',
9797
# Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
9898
result_type=Response, # type: ignore
9999
deps_type=Deps,

examples/pydantic_ai_examples/stream_markdown.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
# models to try, and the appropriate env var
2727
models: list[tuple[KnownModelName, str]] = [
28-
('gemini-1.5-flash', 'GEMINI_API_KEY'),
28+
('google-gla:gemini-1.5-flash', 'GEMINI_API_KEY'),
2929
('openai:gpt-4o-mini', 'OPENAI_API_KEY'),
3030
('groq:llama-3.1-70b-versatile', 'GROQ_API_KEY'),
3131
]

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,12 @@
4848
'groq:mixtral-8x7b-32768',
4949
'groq:gemma2-9b-it',
5050
'groq:gemma-7b-it',
51-
'gemini-1.5-flash',
52-
'gemini-1.5-pro',
53-
'gemini-2.0-flash-exp',
54-
'vertexai:gemini-1.5-flash',
55-
'vertexai:gemini-1.5-pro',
56-
# since mistral models are supported by other providers (e.g. ollama), and some of their models (e.g. "codestral")
57-
# don't start with "mistral", we add the "mistral:" prefix to all to be explicit
51+
'google-gla:gemini-1.5-flash',
52+
'google-gla:gemini-1.5-pro',
53+
'google-gla:gemini-2.0-flash-exp',
54+
'google-vertex:gemini-1.5-flash',
55+
'google-vertex:gemini-1.5-pro',
56+
'google-vertex:gemini-2.0-flash-exp',
5857
'mistral:mistral-small-latest',
5958
'mistral:mistral-large-latest',
6059
'mistral:codestral-latest',
@@ -76,9 +75,9 @@
7675
'ollama:qwen2',
7776
'ollama:qwen2.5',
7877
'ollama:starcoder2',
79-
'claude-3-5-haiku-latest',
80-
'claude-3-5-sonnet-latest',
81-
'claude-3-opus-latest',
78+
'anthropic:claude-3-5-haiku-latest',
79+
'anthropic:claude-3-5-sonnet-latest',
80+
'anthropic:claude-3-opus-latest',
8281
'test',
8382
]
8483
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -274,6 +273,15 @@ def infer_model(model: Model | KnownModelName) -> Model:
274273
from .openai import OpenAIModel
275274

276275
return OpenAIModel(model[7:])
276+
elif model.startswith(('gpt', 'o1')):
277+
from .openai import OpenAIModel
278+
279+
return OpenAIModel(model)
280+
elif model.startswith('google-gla'):
281+
from .gemini import GeminiModel
282+
283+
return GeminiModel(model[11:]) # pyright: ignore[reportArgumentType]
284+
# backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash)
277285
elif model.startswith('gemini'):
278286
from .gemini import GeminiModel
279287

@@ -283,6 +291,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
283291
from .groq import GroqModel
284292

285293
return GroqModel(model[5:]) # pyright: ignore[reportArgumentType]
294+
elif model.startswith('google-vertex'):
295+
from .vertexai import VertexAIModel
296+
297+
return VertexAIModel(model[14:]) # pyright: ignore[reportArgumentType]
298+
# backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash)
286299
elif model.startswith('vertexai:'):
287300
from .vertexai import VertexAIModel
288301

@@ -295,6 +308,11 @@ def infer_model(model: Model | KnownModelName) -> Model:
295308
from .ollama import OllamaModel
296309

297310
return OllamaModel(model[7:])
311+
elif model.startswith('anthropic'):
312+
from .anthropic import AnthropicModel
313+
314+
return AnthropicModel(model[10:])
315+
# backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest)
298316
elif model.startswith('claude'):
299317
from .anthropic import AnthropicModel
300318

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ async def agent_model(
136136
)
137137

138138
def name(self) -> str:
139-
return self.model_name
139+
return f'anthropic:{self.model_name}'
140140

141141
@staticmethod
142142
def _map_tool_definition(f: ToolDefinition) -> ToolParam:

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ async def agent_model(
111111
)
112112

113113
def name(self) -> str:
114-
return self.model_name
114+
return f'google-gla:{self.model_name}'
115115

116116

117117
class AuthProtocol(Protocol):

pydantic_ai_slim/pydantic_ai/models/vertexai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ async def ainit(self) -> tuple[str, BearerTokenAuth]:
164164
return url, auth
165165

166166
def name(self) -> str:
167-
return f'vertexai:{self.model_name}'
167+
return f'google-vertex:{self.model_name}'
168168

169169

170170
# pyright: reportUnknownMemberType=false

tests/models/test_anthropic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
def test_init():
4646
m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
4747
assert m.client.api_key == 'foobar'
48-
assert m.name() == 'claude-3-5-haiku-latest'
48+
assert m.name() == 'anthropic:claude-3-5-haiku-latest'
4949

5050

5151
@dataclass

tests/models/test_model.py

Lines changed: 67 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,84 @@
1+
from importlib import import_module
2+
13
import pytest
24

35
from pydantic_ai import UserError
46
from pydantic_ai.models import infer_model
5-
from pydantic_ai.models.gemini import GeminiModel
67

7-
from ..conftest import TestEnv, try_import
8+
from ..conftest import TestEnv
9+
10+
TEST_CASES = [
11+
('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
12+
('OPENAI_API_KEY', 'gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
13+
('OPENAI_API_KEY', 'o1', 'openai:o1', 'openai', 'OpenAIModel'),
14+
('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
15+
('GEMINI_API_KEY', 'gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
16+
(
17+
'GEMINI_API_KEY',
18+
'google-vertex:gemini-1.5-flash',
19+
'google-vertex:gemini-1.5-flash',
20+
'vertexai',
21+
'VertexAIModel',
22+
),
23+
(
24+
'GEMINI_API_KEY',
25+
'vertexai:gemini-1.5-flash',
26+
'google-vertex:gemini-1.5-flash',
27+
'vertexai',
28+
'VertexAIModel',
29+
),
30+
(
31+
'ANTHROPIC_API_KEY',
32+
'anthropic:claude-3-5-haiku-latest',
33+
'anthropic:claude-3-5-haiku-latest',
34+
'anthropic',
35+
'AnthropicModel',
36+
),
37+
(
38+
'ANTHROPIC_API_KEY',
39+
'claude-3-5-haiku-latest',
40+
'anthropic:claude-3-5-haiku-latest',
41+
'anthropic',
42+
'AnthropicModel',
43+
),
44+
(
45+
'GROQ_API_KEY',
46+
'groq:llama-3.3-70b-versatile',
47+
'groq:llama-3.3-70b-versatile',
48+
'groq',
49+
'GroqModel',
50+
),
51+
('OLLAMA_API_KEY', 'ollama:llama3', 'ollama:llama3', 'ollama', 'OllamaModel'),
52+
(
53+
'MISTRAL_API_KEY',
54+
'mistral:mistral-small-latest',
55+
'mistral:mistral-small-latest',
56+
'mistral',
57+
'MistralModel',
58+
),
59+
]
860

9-
with try_import() as openai_imports_successful:
10-
from pydantic_ai.models.openai import OpenAIModel
1161

12-
with try_import() as vertexai_imports_successful:
13-
from pydantic_ai.models.vertexai import VertexAIModel
62+
@pytest.mark.parametrize('mock_api_key, model_name, expected_model_name, module_name, model_class_name', TEST_CASES)
63+
def test_infer_model(
64+
env: TestEnv, mock_api_key: str, model_name: str, expected_model_name: str, module_name: str, model_class_name: str
65+
):
66+
try:
67+
model_module = import_module(f'pydantic_ai.models.{module_name}')
68+
expected_model = getattr(model_module, model_class_name)
69+
except ImportError:
70+
pytest.skip(f'{model_name} dependencies not installed')
1471

72+
env.set(mock_api_key, 'via-env-var')
1573

16-
@pytest.mark.skipif(not openai_imports_successful(), reason='openai not installed')
17-
def test_infer_str_openai(env: TestEnv):
18-
env.set('OPENAI_API_KEY', 'via-env-var')
19-
m = infer_model('openai:gpt-3.5-turbo')
20-
assert isinstance(m, OpenAIModel)
21-
assert m.name() == 'openai:gpt-3.5-turbo'
74+
m = infer_model(model_name) # pyright: ignore[reportArgumentType]
75+
assert isinstance(m, expected_model)
76+
assert m.name() == expected_model_name
2277

2378
m2 = infer_model(m)
2479
assert m2 is m
2580

2681

27-
def test_infer_str_gemini(env: TestEnv):
28-
env.set('GEMINI_API_KEY', 'via-env-var')
29-
m = infer_model('gemini-1.5-flash')
30-
assert isinstance(m, GeminiModel)
31-
assert m.name() == 'gemini-1.5-flash'
32-
33-
34-
@pytest.mark.skipif(not vertexai_imports_successful(), reason='google-auth not installed')
35-
def test_infer_vertexai(env: TestEnv):
36-
m = infer_model('vertexai:gemini-1.5-flash')
37-
assert isinstance(m, VertexAIModel)
38-
assert m.name() == 'vertexai:gemini-1.5-flash'
39-
40-
4182
def test_infer_str_unknown():
4283
with pytest.raises(UserError, match='Unknown model: foobar'):
4384
infer_model('foobar') # pyright: ignore[reportArgumentType]

tests/models/test_vertexai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async def test_init_service_account(tmp_path: Path, allow_model_requests: None):
4040
'publishers/google/models/gemini-1.5-flash:'
4141
)
4242
assert model.auth is not None
43-
assert model.name() == snapshot('vertexai:gemini-1.5-flash')
43+
assert model.name() == snapshot('google-vertex:gemini-1.5-flash')
4444

4545

4646
class NoOpCredentials:
@@ -67,7 +67,7 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
6767
'publishers/google/models/gemini-1.5-flash:'
6868
)
6969
assert model.auth is not None
70-
assert model.name() == snapshot('vertexai:gemini-1.5-flash')
70+
assert model.name() == snapshot('google-vertex:gemini-1.5-flash')
7171

7272
await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
7373
assert model.url is not None

0 commit comments

Comments
 (0)