From 1ae39da082d8d9efce454e2efb00aac3c274f868 Mon Sep 17 00:00:00 2001
From: sydney-runkle
Date: Thu, 2 Jan 2025 13:53:37 -0500
Subject: [PATCH 1/8] prefix models with provider

---
 pydantic_ai_slim/pydantic_ai/models/__init__.py  | 10 +++++-----
 pydantic_ai_slim/pydantic_ai/models/anthropic.py |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 1a1bc9f47b..130f34570b 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -76,9 +76,9 @@
     'ollama:qwen2',
     'ollama:qwen2.5',
     'ollama:starcoder2',
-    'claude-3-5-haiku-latest',
-    'claude-3-5-sonnet-latest',
-    'claude-3-opus-latest',
+    'anthropic:claude-3-5-haiku-latest',
+    'anthropic:claude-3-5-sonnet-latest',
+    'anthropic:claude-3-opus-latest',
     'test',
 ]
 """Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
@@ -295,10 +295,10 @@ def infer_model(model: Model | KnownModelName) -> Model:
         from .ollama import OllamaModel
 
         return OllamaModel(model[7:])
-    elif model.startswith('claude'):
+    elif model.startswith('anthropic'):
         from .anthropic import AnthropicModel
 
-        return AnthropicModel(model)
+        return AnthropicModel(model[10:])
     else:
         raise UserError(f'Unknown model: {model}')
 
diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
index a49f3eb4cb..f4674812b6 100644
--- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -136,7 +136,7 @@ async def agent_model(
         )
 
     def name(self) -> str:
-        return self.model_name
+        return f'anthropic:{self.model_name}'
 
     @staticmethod
     def _map_tool_definition(f: ToolDefinition) -> ToolParam:

From 87f3d9e76548c1efa16fac24bd605ab007680802 Mon Sep 17 00:00:00 2001
From: sydney-runkle
Date: Thu, 2 Jan 2025 13:54:28 -0500
Subject: [PATCH 2/8] fix test

---
 tests/models/test_anthropic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py
index 1fad2628e7..66a16605ae 100644
--- a/tests/models/test_anthropic.py
+++ b/tests/models/test_anthropic.py
@@ -45,7 +45,7 @@ def test_init():
     m = AnthropicModel('claude-3-5-haiku-latest', api_key='foobar')
     assert m.client.api_key == 'foobar'
-    assert m.name() == 'claude-3-5-haiku-latest'
+    assert m.name() == 'anthropic:claude-3-5-haiku-latest'
 
 
 @dataclass

From 0e030dd610613bc2fd531aea57a36e813a6d5b2c Mon Sep 17 00:00:00 2001
From: sydney-runkle
Date: Thu, 2 Jan 2025 14:21:04 -0500
Subject: [PATCH 3/8] prefix with provider consistently

---
 docs/models.md                                   |  6 +++-
 examples/pydantic_ai_examples/sql_gen.py         |  2 +-
 .../pydantic_ai_examples/stream_markdown.py      |  2 +-
 .../pydantic_ai/models/__init__.py               | 32 +++++++++++++++----
 pydantic_ai_slim/pydantic_ai/models/gemini.py    |  2 +-
 .../pydantic_ai/models/vertexai.py               |  2 +-
 tests/models/test_model.py                       |  8 ++---
 tests/models/test_vertexai.py                    |  4 +--
 tests/test_agent.py                              |  4 +--
 9 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/docs/models.md b/docs/models.md
index af27118141..3ca1fb65c9 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -199,10 +199,14 @@ You can then use [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] by name:
 ```python {title="gemini_model_by_name.py"}
 from pydantic_ai import Agent
 
-agent = Agent('gemini-1.5-flash')
+agent = Agent('google-gla:gemini-1.5-flash')
 ...
 ```
 
+!!! note
+    The `google-gla` provider prefix represents the "google generative language api" for `GeminiModel`s.
+    `google-vertex` is used with `VertexAIModel`s.
+
 Or initialise the model directly with just the model name:
 
 ```python {title="gemini_model_init.py"}
diff --git a/examples/pydantic_ai_examples/sql_gen.py b/examples/pydantic_ai_examples/sql_gen.py
index f636fffe23..e38febf21a 100644
--- a/examples/pydantic_ai_examples/sql_gen.py
+++ b/examples/pydantic_ai_examples/sql_gen.py
@@ -74,7 +74,7 @@ class InvalidRequest(BaseModel):
 Response: TypeAlias = Union[Success, InvalidRequest]
 
 agent: Agent[Deps, Response] = Agent(
-    'gemini-1.5-flash',
+    'google-gla:gemini-1.5-flash',
     # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else
     result_type=Response,  # type: ignore
     deps_type=Deps,
diff --git a/examples/pydantic_ai_examples/stream_markdown.py b/examples/pydantic_ai_examples/stream_markdown.py
index bc2362f724..9d95cae09d 100644
--- a/examples/pydantic_ai_examples/stream_markdown.py
+++ b/examples/pydantic_ai_examples/stream_markdown.py
@@ -25,7 +25,7 @@
 
 # models to try, and the appropriate env var
 models: list[tuple[KnownModelName, str]] = [
-    ('gemini-1.5-flash', 'GEMINI_API_KEY'),
+    ('google-gla:gemini-1.5-flash', 'GEMINI_API_KEY'),
     ('openai:gpt-4o-mini', 'OPENAI_API_KEY'),
     ('groq:llama-3.1-70b-versatile', 'GROQ_API_KEY'),
 ]
diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 130f34570b..61b310996f 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -48,13 +48,12 @@
     'groq:mixtral-8x7b-32768',
     'groq:gemma2-9b-it',
     'groq:gemma-7b-it',
-    'gemini-1.5-flash',
-    'gemini-1.5-pro',
-    'gemini-2.0-flash-exp',
-    'vertexai:gemini-1.5-flash',
-    'vertexai:gemini-1.5-pro',
-    # since mistral models are supported by other providers (e.g. ollama), and some of their models (e.g. "codestral")
"codestral") - # don't start with "mistral", we add the "mistral:" prefix to all to be explicit + 'google-gla:gemini-1.5-flash', + 'google-gla:gemini-1.5-pro', + 'google-gla:gemini-2.0-flash-exp', + 'google-vertex:gemini-1.5-flash', + 'google-vertex:gemini-1.5-pro', + 'google-vertex:gemini-2.0-flash-exp', 'mistral:mistral-small-latest', 'mistral:mistral-large-latest', 'mistral:codestral-latest', @@ -274,6 +273,15 @@ def infer_model(model: Model | KnownModelName) -> Model: from .openai import OpenAIModel return OpenAIModel(model[7:]) + elif model.startswith(('gpt', 'o1')): + from .openai import OpenAIModel + + return OpenAIModel(model) + elif model.startswith('google-gla'): + from .gemini import GeminiModel + + return GeminiModel(model[11:]) # pyright: ignore[reportArgumentType] + # backwards compatibility with old model names (ex, gemini-1.5-flash -> google-gla:gemini-1.5-flash) elif model.startswith('gemini'): from .gemini import GeminiModel @@ -283,6 +291,11 @@ def infer_model(model: Model | KnownModelName) -> Model: from .groq import GroqModel return GroqModel(model[5:]) # pyright: ignore[reportArgumentType] + elif model.startswith('google-vertex'): + from .vertexai import VertexAIModel + + return VertexAIModel(model[14:]) # pyright: ignore[reportArgumentType] + # backwards compatibility with old model names (ex, vertexai:gemini-1.5-flash -> google-vertex:gemini-1.5-flash) elif model.startswith('vertexai:'): from .vertexai import VertexAIModel @@ -299,6 +312,11 @@ def infer_model(model: Model | KnownModelName) -> Model: from .anthropic import AnthropicModel return AnthropicModel(model[10:]) + # backwards compatibility with old model names (ex, claude-3-5-sonnet-latest -> anthropic:claude-3-5-sonnet-latest) + elif model.startswith('claude'): + from .anthropic import AnthropicModel + + return AnthropicModel(model) else: raise UserError(f'Unknown model: {model}') diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 3ca33cd35c..2522fbdf43 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -111,7 +111,7 @@ async def agent_model( ) def name(self) -> str: - return self.model_name + return f'google-gla:{self.model_name}' class AuthProtocol(Protocol): diff --git a/pydantic_ai_slim/pydantic_ai/models/vertexai.py b/pydantic_ai_slim/pydantic_ai/models/vertexai.py index 794a82dea2..64830817ff 100644 --- a/pydantic_ai_slim/pydantic_ai/models/vertexai.py +++ b/pydantic_ai_slim/pydantic_ai/models/vertexai.py @@ -164,7 +164,7 @@ async def ainit(self) -> tuple[str, BearerTokenAuth]: return url, auth def name(self) -> str: - return f'vertexai:{self.model_name}' + return f'google-vertex:{self.model_name}' # pyright: reportUnknownMemberType=false diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 1695fc7e63..ac6a7c7851 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -26,16 +26,16 @@ def test_infer_str_openai(env: TestEnv): def test_infer_str_gemini(env: TestEnv): env.set('GEMINI_API_KEY', 'via-env-var') - m = infer_model('gemini-1.5-flash') + m = infer_model('google-gla:gemini-1.5-flash') assert isinstance(m, GeminiModel) - assert m.name() == 'gemini-1.5-flash' + assert m.name() == 'google-gla:gemini-1.5-flash' @pytest.mark.skipif(not vertexai_imports_successful(), reason='google-auth not installed') def test_infer_vertexai(env: TestEnv): - m = infer_model('vertexai:gemini-1.5-flash') + m = infer_model('google-vertex:gemini-1.5-flash') 
     assert isinstance(m, VertexAIModel)
-    assert m.name() == 'vertexai:gemini-1.5-flash'
+    assert m.name() == 'google-vertex:gemini-1.5-flash'
 
 
 def test_infer_str_unknown():
diff --git a/tests/models/test_vertexai.py b/tests/models/test_vertexai.py
index 73635945d6..25fca341ed 100644
--- a/tests/models/test_vertexai.py
+++ b/tests/models/test_vertexai.py
@@ -40,7 +40,7 @@ async def test_init_service_account(tmp_path: Path, allow_model_requests: None):
         'publishers/google/models/gemini-1.5-flash:'
     )
     assert model.auth is not None
-    assert model.name() == snapshot('vertexai:gemini-1.5-flash')
+    assert model.name() == snapshot('google-vertex:gemini-1.5-flash')
 
 
 class NoOpCredentials:
@@ -67,7 +67,7 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
         'publishers/google/models/gemini-1.5-flash:'
     )
     assert model.auth is not None
-    assert model.name() == snapshot('vertexai:gemini-1.5-flash')
+    assert model.name() == snapshot('google-vertex:gemini-1.5-flash')
 
     await model.agent_model(function_tools=[], allow_text_result=True, result_tools=[])
     assert model.url is not None
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 277eeeb521..63fb3e7825 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -780,7 +780,7 @@ def empty(m: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
 
 def test_model_requests_blocked(env: TestEnv, set_event_loop: None):
     env.set('GEMINI_API_KEY', 'foobar')
-    agent = Agent('gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)
+    agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[str, str], defer_model_check=True)
 
     with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'):
         agent.run_sync('Hello')
 
 
 def test_override_model(env: TestEnv, set_event_loop: None):
     env.set('GEMINI_API_KEY', 'foobar')
-    agent = Agent('gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)
+    agent = Agent('google-gla:gemini-1.5-flash', result_type=tuple[int, str], defer_model_check=True)
 
     with agent.override(model='test'):
         result = agent.run_sync('Hello')

From 40a74c0082cc40a396024ae5ee512698c2fd429a Mon Sep 17 00:00:00 2001
From: sydney-runkle
Date: Thu, 2 Jan 2025 14:30:03 -0500
Subject: [PATCH 4/8] infer tests

---
 tests/models/test_model.py | 45 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index ac6a7c7851..1f3d7d12f5 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -12,6 +12,9 @@
 with try_import() as vertexai_imports_successful:
     from pydantic_ai.models.vertexai import VertexAIModel
 
+with try_import() as anthropic_imports_successful:
+    from pydantic_ai.models.anthropic import AnthropicModel
+
 
 @pytest.mark.skipif(not openai_imports_successful(), reason='openai not installed')
 def test_infer_str_openai(env: TestEnv):
@@ -24,6 +27,17 @@ def test_infer_str_openai(env: TestEnv):
     assert m2 is m
 
 
+@pytest.mark.parametrize('model_name', ['gpt-3.5-turbo', 'o1'])
+def test_infer_str_openai_without_provider(env: TestEnv, model_name: str):
+    env.set('OPENAI_API_KEY', 'via-env-var')
+    m = infer_model(model_name)  # pyright: ignore[reportArgumentType]
+    assert isinstance(m, OpenAIModel)
+    assert m.name() == f'openai:{model_name}'
+
+    m2 = infer_model(m)
+    assert m2 is m
+
+
 def test_infer_str_gemini(env: TestEnv):
     env.set('GEMINI_API_KEY', 'via-env-var')
     m = infer_model('google-gla:gemini-1.5-flash')
     assert isinstance(m, GeminiModel)
     assert m.name() == 'google-gla:gemini-1.5-flash'
 
 
+def test_infer_str_gemini_without_provider(env: TestEnv):
+    env.set('GEMINI_API_KEY', 'via-env-var')
+    m = infer_model('gemini-1.5-flash')  # pyright: ignore[reportArgumentType]
+    assert isinstance(m, GeminiModel)
+    assert m.name() == 'google-gla:gemini-1.5-flash'
+
+    m2 = infer_model(m)
+    assert m2 is m
+
+
 @pytest.mark.skipif(not vertexai_imports_successful(), reason='google-auth not installed')
 def test_infer_vertexai(env: TestEnv):
     m = infer_model('google-vertex:gemini-1.5-flash')
@@ -38,6 +62,27 @@ def test_infer_vertexai(env: TestEnv):
     assert m.name() == 'google-vertex:gemini-1.5-flash'
 
 
+@pytest.mark.skipif(not anthropic_imports_successful(), reason='anthropic not installed')
+def test_infer_anthropic(env: TestEnv):
+    env.set('ANTHROPIC_API_KEY', 'via-env-var')
+    m = infer_model('anthropic:claude-3-5-haiku-latest')
+    assert isinstance(m, AnthropicModel)
+    assert m.name() == 'anthropic:claude-3-5-haiku-latest'
+
+    m2 = infer_model(m)
+    assert m2 is m
+
+
+def test_infer_anthropic_no_provider(env: TestEnv):
+    env.set('ANTHROPIC_API_KEY', 'via-env-var')
+    m = infer_model('claude-3-5-haiku-latest')  # pyright: ignore[reportArgumentType]
+    assert isinstance(m, AnthropicModel)
+    assert m.name() == 'anthropic:claude-3-5-haiku-latest'
+
+    m2 = infer_model(m)
+    assert m2 is m
+
+
 def test_infer_str_unknown():
     with pytest.raises(UserError, match='Unknown model: foobar'):
         infer_model('foobar')  # pyright: ignore[reportArgumentType]

From 6606339fae60daf220b7223a3f381bbb09a15b2a Mon Sep 17 00:00:00 2001
From: sydney-runkle
Date: Thu, 2 Jan 2025 14:46:32 -0500
Subject: [PATCH 5/8] parametrized test

---
 tests/models/test_model.py | 142 +++++++++++++++++++++----------------
 1 file changed, 80 insertions(+), 62 deletions(-)

diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 1f3d7d12f5..1fd8ea5c24 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -1,7 +1,9 @@
+from typing import Callable
+
 import pytest
 
 from pydantic_ai import UserError
-from pydantic_ai.models import infer_model
+from pydantic_ai.models import Model, infer_model
 from pydantic_ai.models.gemini import GeminiModel
 
 from ..conftest import TestEnv, try_import
@@ -15,69 +17,85 @@
 with try_import() as vertexai_imports_successful:
     from pydantic_ai.models.vertexai import VertexAIModel
 
 with try_import() as anthropic_imports_successful:
     from pydantic_ai.models.anthropic import AnthropicModel
 
+with try_import() as groq_imports_successful:
+    from pydantic_ai.models.groq import GroqModel
+
+with try_import() as ollama_imports_successful:
+    from pydantic_ai.models.ollama import OllamaModel
+
+with try_import() as mistral_imports_successful:
+    from pydantic_ai.models.mistral import MistralModel
+
+
+TEST_CASES = [
+    ('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'openai:gpt-3.5-turbo', OpenAIModel, openai_imports_successful),
+    ('OPENAI_API_KEY', 'gpt-3.5-turbo', 'openai:gpt-3.5-turbo', OpenAIModel, openai_imports_successful),
+    ('OPENAI_API_KEY', 'o1', 'openai:o1', OpenAIModel, openai_imports_successful),
+    ('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'google-gla:gemini-1.5-flash', GeminiModel, lambda: True),
+    ('GEMINI_API_KEY', 'gemini-1.5-flash', 'google-gla:gemini-1.5-flash', GeminiModel, lambda: True),
+    (
+        'GEMINI_API_KEY',
+        'google-vertex:gemini-1.5-flash',
+        'google-vertex:gemini-1.5-flash',
+        VertexAIModel,
+        vertexai_imports_successful,
+    ),
+    (
+        'GEMINI_API_KEY',
+        'vertexai:gemini-1.5-flash',
+        'google-vertex:gemini-1.5-flash',
+        VertexAIModel,
+        vertexai_imports_successful,
+    ),
+    (
+        'ANTHROPIC_API_KEY',
+        'anthropic:claude-3-5-haiku-latest',
+        'anthropic:claude-3-5-haiku-latest',
+        AnthropicModel,
+        anthropic_imports_successful,
+    ),
+    (
+        'ANTHROPIC_API_KEY',
+        'claude-3-5-haiku-latest',
+        'anthropic:claude-3-5-haiku-latest',
+        AnthropicModel,
+        anthropic_imports_successful,
+    ),
+    (
+        'GROQ_API_KEY',
+        'groq:llama-3.3-70b-versatile',
+        'groq:llama-3.3-70b-versatile',
+        GroqModel,
+        groq_imports_successful,
+    ),
+    ('OLLAMA_API_KEY', 'ollama:llama3', 'ollama:llama3', OllamaModel, ollama_imports_successful),
+    (
+        'MISTRAL_API_KEY',
+        'mistral:mistral-small-latest',
+        'mistral:mistral-small-latest',
+        MistralModel,
+        mistral_imports_successful,
+    ),
+]
+
+
+@pytest.mark.parametrize('mock_api_key, model_name, expected_model_name, expected_model, deps_available', TEST_CASES)
+def test_infer_model(
+    env: TestEnv,
+    mock_api_key: str,
+    model_name: str,
+    expected_model_name: str,
+    expected_model: type[Model],
+    deps_available: Callable[[], bool],
+):
+    if not deps_available():
+        pytest.skip(f'{expected_model.__name__} not installed')
+
+    env.set(mock_api_key, 'via-env-var')
 
-@pytest.mark.skipif(not openai_imports_successful(), reason='openai not installed')
-def test_infer_str_openai(env: TestEnv):
-    env.set('OPENAI_API_KEY', 'via-env-var')
-    m = infer_model('openai:gpt-3.5-turbo')
-    assert isinstance(m, OpenAIModel)
-    assert m.name() == 'openai:gpt-3.5-turbo'
-
-    m2 = infer_model(m)
-    assert m2 is m
-
-
-@pytest.mark.parametrize('model_name', ['gpt-3.5-turbo', 'o1'])
-def test_infer_str_openai_without_provider(env: TestEnv, model_name: str):
-    env.set('OPENAI_API_KEY', 'via-env-var')
-    m = infer_model(model_name)  # pyright: ignore[reportArgumentType]
-    assert isinstance(m, OpenAIModel)
-    assert m.name() == f'openai:{model_name}'
-
-    m2 = infer_model(m)
-    assert m2 is m
-
-
-def test_infer_str_gemini(env: TestEnv):
-    env.set('GEMINI_API_KEY', 'via-env-var')
-    m = infer_model('google-gla:gemini-1.5-flash')
-    assert isinstance(m, GeminiModel)
-    assert m.name() == 'google-gla:gemini-1.5-flash'
-
-
-def test_infer_str_gemini_without_provider(env: TestEnv):
-    env.set('GEMINI_API_KEY', 'via-env-var')
-    m = infer_model('gemini-1.5-flash')  # pyright: ignore[reportArgumentType]
-    assert isinstance(m, GeminiModel)
-    assert m.name() == 'google-gla:gemini-1.5-flash'
-
-    m2 = infer_model(m)
-    assert m2 is m
-
-
-@pytest.mark.skipif(not vertexai_imports_successful(), reason='google-auth not installed')
-def test_infer_vertexai(env: TestEnv):
-    m = infer_model('google-vertex:gemini-1.5-flash')
-    assert isinstance(m, VertexAIModel)
-    assert m.name() == 'google-vertex:gemini-1.5-flash'
-
-
-@pytest.mark.skipif(not anthropic_imports_successful(), reason='anthropic not installed')
-def test_infer_anthropic(env: TestEnv):
-    env.set('ANTHROPIC_API_KEY', 'via-env-var')
-    m = infer_model('anthropic:claude-3-5-haiku-latest')
-    assert isinstance(m, AnthropicModel)
-    assert m.name() == 'anthropic:claude-3-5-haiku-latest'
-
-    m2 = infer_model(m)
-    assert m2 is m
-
-
-def test_infer_anthropic_no_provider(env: TestEnv):
-    env.set('ANTHROPIC_API_KEY', 'via-env-var')
-    m = infer_model('claude-3-5-haiku-latest')  # pyright: ignore[reportArgumentType]
-    assert isinstance(m, AnthropicModel)
-    assert m.name() == 'anthropic:claude-3-5-haiku-latest'
+    m = infer_model(model_name)  # pyright: ignore[reportArgumentType]
+    assert isinstance(m, expected_model)
+    assert m.name() == expected_model_name
 
     m2 = infer_model(m)
     assert m2 is m

From f377083178e59e933e0a800d38c7da54cf092784 Mon Sep 17 00:00:00 2001
From: sydney-runkle
Date: Thu, 2 Jan 2025 15:00:37 -0500
Subject: [PATCH 6/8] dynamic imports

---
 tests/models/test_model.py | 78 ++++++++++++++------------------------
 1 file changed, 28 insertions(+), 50 deletions(-)

diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 1fd8ea5c24..625be03aae 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -1,95 +1,73 @@
-from typing import Callable
+from importlib import import_module
 
 import pytest
 
 from pydantic_ai import UserError
-from pydantic_ai.models import Model, infer_model
-from pydantic_ai.models.gemini import GeminiModel
-
-from ..conftest import TestEnv, try_import
-
-with try_import() as openai_imports_successful:
-    from pydantic_ai.models.openai import OpenAIModel
-
-with try_import() as vertexai_imports_successful:
-    from pydantic_ai.models.vertexai import VertexAIModel
-
-with try_import() as anthropic_imports_successful:
-    from pydantic_ai.models.anthropic import AnthropicModel
-
-with try_import() as groq_imports_successful:
-    from pydantic_ai.models.groq import GroqModel
-
-with try_import() as ollama_imports_successful:
-    from pydantic_ai.models.ollama import OllamaModel
-
-with try_import() as mistral_imports_successful:
-    from pydantic_ai.models.mistral import MistralModel
+from pydantic_ai.models import infer_model
+
+from ..conftest import TestEnv
 
 TEST_CASES = [
-    ('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'openai:gpt-3.5-turbo', OpenAIModel, openai_imports_successful),
-    ('OPENAI_API_KEY', 'gpt-3.5-turbo', 'openai:gpt-3.5-turbo', OpenAIModel, openai_imports_successful),
-    ('OPENAI_API_KEY', 'o1', 'openai:o1', OpenAIModel, openai_imports_successful),
-    ('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'google-gla:gemini-1.5-flash', GeminiModel, lambda: True),
-    ('GEMINI_API_KEY', 'gemini-1.5-flash', 'google-gla:gemini-1.5-flash', GeminiModel, lambda: True),
+    ('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
+    ('OPENAI_API_KEY', 'gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'),
+    ('OPENAI_API_KEY', 'o1', 'openai:o1', 'openai', 'OpenAIModel'),
+    ('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
+    ('GEMINI_API_KEY', 'gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'),
     (
         'GEMINI_API_KEY',
         'google-vertex:gemini-1.5-flash',
         'google-vertex:gemini-1.5-flash',
-        VertexAIModel,
-        vertexai_imports_successful,
+        'vertexai',
+        'VertexAIModel',
     ),
     (
         'GEMINI_API_KEY',
         'vertexai:gemini-1.5-flash',
         'google-vertex:gemini-1.5-flash',
-        VertexAIModel,
-        vertexai_imports_successful,
+        'vertexai',
+        'VertexAIModel',
     ),
     (
         'ANTHROPIC_API_KEY',
         'anthropic:claude-3-5-haiku-latest',
         'anthropic:claude-3-5-haiku-latest',
-        AnthropicModel,
-        anthropic_imports_successful,
+        'anthropic',
+        'AnthropicModel',
     ),
     (
         'ANTHROPIC_API_KEY',
         'claude-3-5-haiku-latest',
         'anthropic:claude-3-5-haiku-latest',
-        AnthropicModel,
-        anthropic_imports_successful,
+        'anthropic',
+        'AnthropicModel',
     ),
     (
         'GROQ_API_KEY',
         'groq:llama-3.3-70b-versatile',
         'groq:llama-3.3-70b-versatile',
-        GroqModel,
-        groq_imports_successful,
+        'groq',
+        'GroqModel',
     ),
-    ('OLLAMA_API_KEY', 'ollama:llama3', 'ollama:llama3', OllamaModel, ollama_imports_successful),
+    ('OLLAMA_API_KEY', 'ollama:llama3', 'ollama:llama3', 'ollama', 'OllamaModel'),
     (
         'MISTRAL_API_KEY',
         'mistral:mistral-small-latest',
         'mistral:mistral-small-latest',
-        MistralModel,
-        mistral_imports_successful,
+        'mistral',
+        'MistralModel',
     ),
 ]
 
 
-@pytest.mark.parametrize('mock_api_key, model_name, expected_model_name, expected_model, deps_available', TEST_CASES)
+@pytest.mark.parametrize('mock_api_key, model_name, expected_model_name, module_name, model_class_name', TEST_CASES)
 def test_infer_model(
-    env: TestEnv,
-    mock_api_key: str,
-    model_name: str,
-    expected_model_name: str,
-    expected_model: type[Model],
-    deps_available: Callable[[], bool],
+    env: TestEnv, mock_api_key: str, model_name: str, expected_model_name: str, module_name: str, model_class_name: str
 ):
-    if not deps_available():
-        pytest.skip(f'{expected_model.__name__} not installed')
+    try:
+        model_module = import_module(f'pydantic_ai.models.{module_name}')
+        expected_model = getattr(model_module, model_class_name)
+    except ImportError:
+        pytest.skip(f'{model_name} dependencies not installed')
 
     env.set(mock_api_key, 'via-env-var')

From 59cabf7a50c89598963c2816bccb2e3a65d48b0b Mon Sep 17 00:00:00 2001
From: sydney-runkle
Date: Tue, 7 Jan 2025 10:19:43 -0500
Subject: [PATCH 7/8] docs updates from samuel's suggestion

---
 docs/models.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/models.md b/docs/models.md
index 3ca1fb65c9..5f0f984dd6 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -204,8 +204,8 @@ agent = Agent('google-gla:gemini-1.5-flash')
 ```
 
 !!! note
-    The `google-gla` provider prefix represents the "google generative language api" for `GeminiModel`s.
-    `google-vertex` is used with `VertexAIModel`s.
+    The `google-gla` provider prefix represents the [google generative language api](https://ai.google.dev/api/all-methods) for `GeminiModel`s.
+    `google-vertex` is used with [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models) for `VertexAIModel`s.
 
 Or initialise the model directly with just the model name:
 

From a63b832ea45d13708322101707dff3b2216fd432 Mon Sep 17 00:00:00 2001
From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com>
Date: Tue, 7 Jan 2025 10:49:22 -0500
Subject: [PATCH 8/8] Update docs/models.md

Co-authored-by: Samuel Colvin
---
 docs/models.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/models.md b/docs/models.md
index 5f0f984dd6..b9265abe07 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -204,7 +204,7 @@ agent = Agent('google-gla:gemini-1.5-flash')
 ```
 
 !!! note
-    The `google-gla` provider prefix represents the [google generative language api](https://ai.google.dev/api/all-methods) for `GeminiModel`s.
+    The `google-gla` provider prefix represents the [Google **G**enerative **L**anguage **A**PI](https://ai.google.dev/api/all-methods) for `GeminiModel`s.
     `google-vertex` is used with [Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models) for `VertexAIModel`s.
 
 Or initialise the model directly with just the model name:
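
A minimal usage sketch of the naming scheme this series settles on — not part of the patches themselves, and it assumes `pydantic_ai` is installed with the relevant optional dependencies and that the matching API-key environment variables (e.g. `ANTHROPIC_API_KEY`, `GEMINI_API_KEY`) are set:

```python
from pydantic_ai import Agent
from pydantic_ai.models import infer_model

# Provider-prefixed names are the canonical form after this series.
agent = Agent('anthropic:claude-3-5-haiku-latest')

# Bare legacy names still resolve via the backwards-compatibility branches,
# but the resolved model reports the provider-prefixed name.
m = infer_model('gemini-1.5-flash')
print(m.name())  # 'google-gla:gemini-1.5-flash'
```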