|
| 1 | +from importlib import import_module |
| 2 | + |
1 | 3 | import pytest
|
2 | 4 |
|
3 | 5 | from pydantic_ai import UserError
|
4 | 6 | from pydantic_ai.models import infer_model
|
5 |
| -from pydantic_ai.models.gemini import GeminiModel |
6 | 7 |
|
7 |
| -from ..conftest import TestEnv, try_import |
| 8 | +from ..conftest import TestEnv |
| 9 | + |
| 10 | +TEST_CASES = [ |
| 11 | + ('OPENAI_API_KEY', 'openai:gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'), |
| 12 | + ('OPENAI_API_KEY', 'gpt-3.5-turbo', 'openai:gpt-3.5-turbo', 'openai', 'OpenAIModel'), |
| 13 | + ('OPENAI_API_KEY', 'o1', 'openai:o1', 'openai', 'OpenAIModel'), |
| 14 | + ('GEMINI_API_KEY', 'google-gla:gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'), |
| 15 | + ('GEMINI_API_KEY', 'gemini-1.5-flash', 'google-gla:gemini-1.5-flash', 'gemini', 'GeminiModel'), |
| 16 | + ( |
| 17 | + 'GEMINI_API_KEY', |
| 18 | + 'google-vertex:gemini-1.5-flash', |
| 19 | + 'google-vertex:gemini-1.5-flash', |
| 20 | + 'vertexai', |
| 21 | + 'VertexAIModel', |
| 22 | + ), |
| 23 | + ( |
| 24 | + 'GEMINI_API_KEY', |
| 25 | + 'vertexai:gemini-1.5-flash', |
| 26 | + 'google-vertex:gemini-1.5-flash', |
| 27 | + 'vertexai', |
| 28 | + 'VertexAIModel', |
| 29 | + ), |
| 30 | + ( |
| 31 | + 'ANTHROPIC_API_KEY', |
| 32 | + 'anthropic:claude-3-5-haiku-latest', |
| 33 | + 'anthropic:claude-3-5-haiku-latest', |
| 34 | + 'anthropic', |
| 35 | + 'AnthropicModel', |
| 36 | + ), |
| 37 | + ( |
| 38 | + 'ANTHROPIC_API_KEY', |
| 39 | + 'claude-3-5-haiku-latest', |
| 40 | + 'anthropic:claude-3-5-haiku-latest', |
| 41 | + 'anthropic', |
| 42 | + 'AnthropicModel', |
| 43 | + ), |
| 44 | + ( |
| 45 | + 'GROQ_API_KEY', |
| 46 | + 'groq:llama-3.3-70b-versatile', |
| 47 | + 'groq:llama-3.3-70b-versatile', |
| 48 | + 'groq', |
| 49 | + 'GroqModel', |
| 50 | + ), |
| 51 | + ('OLLAMA_API_KEY', 'ollama:llama3', 'ollama:llama3', 'ollama', 'OllamaModel'), |
| 52 | + ( |
| 53 | + 'MISTRAL_API_KEY', |
| 54 | + 'mistral:mistral-small-latest', |
| 55 | + 'mistral:mistral-small-latest', |
| 56 | + 'mistral', |
| 57 | + 'MistralModel', |
| 58 | + ), |
| 59 | +] |
8 | 60 |
|
9 |
| -with try_import() as openai_imports_successful: |
10 |
| - from pydantic_ai.models.openai import OpenAIModel |
11 | 61 |
|
12 |
| -with try_import() as vertexai_imports_successful: |
13 |
| - from pydantic_ai.models.vertexai import VertexAIModel |
| 62 | +@pytest.mark.parametrize('mock_api_key, model_name, expected_model_name, module_name, model_class_name', TEST_CASES) |
| 63 | +def test_infer_model( |
| 64 | + env: TestEnv, mock_api_key: str, model_name: str, expected_model_name: str, module_name: str, model_class_name: str |
| 65 | +): |
| 66 | + try: |
| 67 | + model_module = import_module(f'pydantic_ai.models.{module_name}') |
| 68 | + expected_model = getattr(model_module, model_class_name) |
| 69 | + except ImportError: |
| 70 | + pytest.skip(f'{model_name} dependencies not installed') |
14 | 71 |
|
| 72 | + env.set(mock_api_key, 'via-env-var') |
15 | 73 |
|
16 |
| -@pytest.mark.skipif(not openai_imports_successful(), reason='openai not installed') |
17 |
| -def test_infer_str_openai(env: TestEnv): |
18 |
| - env.set('OPENAI_API_KEY', 'via-env-var') |
19 |
| - m = infer_model('openai:gpt-3.5-turbo') |
20 |
| - assert isinstance(m, OpenAIModel) |
21 |
| - assert m.name() == 'openai:gpt-3.5-turbo' |
| 74 | + m = infer_model(model_name) # pyright: ignore[reportArgumentType] |
| 75 | + assert isinstance(m, expected_model) |
| 76 | + assert m.name() == expected_model_name |
22 | 77 |
|
23 | 78 | m2 = infer_model(m)
|
24 | 79 | assert m2 is m
|
25 | 80 |
|
26 | 81 |
|
27 |
| -def test_infer_str_gemini(env: TestEnv): |
28 |
| - env.set('GEMINI_API_KEY', 'via-env-var') |
29 |
| - m = infer_model('gemini-1.5-flash') |
30 |
| - assert isinstance(m, GeminiModel) |
31 |
| - assert m.name() == 'gemini-1.5-flash' |
32 |
| - |
33 |
| - |
34 |
| -@pytest.mark.skipif(not vertexai_imports_successful(), reason='google-auth not installed') |
35 |
| -def test_infer_vertexai(env: TestEnv): |
36 |
| - m = infer_model('vertexai:gemini-1.5-flash') |
37 |
| - assert isinstance(m, VertexAIModel) |
38 |
| - assert m.name() == 'vertexai:gemini-1.5-flash' |
39 |
| - |
40 |
| - |
41 | 82 | def test_infer_str_unknown():
|
42 | 83 | with pytest.raises(UserError, match='Unknown model: foobar'):
|
43 | 84 | infer_model('foobar') # pyright: ignore[reportArgumentType]
|
0 commit comments