diff --git a/pydantic_ai_slim/pydantic_ai/providers/anthropic.py b/pydantic_ai_slim/pydantic_ai/providers/anthropic.py index d96b70ed55..d02a307d08 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/providers/anthropic.py @@ -45,12 +45,15 @@ def model_profile(self, model_name: str) -> ModelProfile | None: def __init__(self, *, anthropic_client: AsyncAnthropicClient | None = None) -> None: ... @overload - def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ... + def __init__( + self, *, api_key: str | None = None, base_url: str | None = None, http_client: httpx.AsyncClient | None = None + ) -> None: ... def __init__( self, *, api_key: str | None = None, + base_url: str | None = None, anthropic_client: AsyncAnthropicClient | None = None, http_client: httpx.AsyncClient | None = None, ) -> None: @@ -59,6 +62,7 @@ def __init__( Args: api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable will be used if available. + base_url: The base URL to use for the Anthropic API. anthropic_client: An existing [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python) client to use. If provided, the `api_key` and `http_client` arguments will be ignored. http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. @@ -75,7 +79,7 @@ def __init__( 'to use the Anthropic provider.' ) if http_client is not None: - self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) + self._client = AsyncAnthropic(api_key=api_key, base_url=base_url, http_client=http_client) else: http_client = cached_async_http_client(provider='anthropic') - self._client = AsyncAnthropic(api_key=api_key, http_client=http_client) + self._client = AsyncAnthropic(api_key=api_key, base_url=base_url, http_client=http_client) diff --git a/pydantic_ai_slim/pydantic_ai/providers/gateway.py b/pydantic_ai_slim/pydantic_ai/providers/gateway.py index a82af13063..9d3a8cab59 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/gateway.py +++ b/pydantic_ai_slim/pydantic_ai/providers/gateway.py @@ -16,6 +16,7 @@ from groq import AsyncGroq from openai import AsyncOpenAI + from pydantic_ai.models.anthropic import AsyncAnthropicClient from pydantic_ai.providers import Provider @@ -48,8 +49,17 @@ def gateway_provider( ) -> Provider[GoogleClient]: ... +@overload +def gateway_provider( + upstream_provider: Literal['anthropic'], + *, + api_key: str | None = None, + base_url: str | None = None, +) -> Provider[AsyncAnthropicClient]: ... + + def gateway_provider( - upstream_provider: Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex'] | str, + upstream_provider: Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic'] | str, *, # Every provider api_key: str | None = None, @@ -90,6 +100,18 @@ def gateway_provider( from .groq import GroqProvider return GroqProvider(api_key=api_key, base_url=urljoin(base_url, 'groq'), http_client=http_client) + elif upstream_provider == 'anthropic': + from anthropic import AsyncAnthropic + + from .anthropic import AnthropicProvider + + return AnthropicProvider( + anthropic_client=AsyncAnthropic( + auth_token=api_key, + base_url=urljoin(base_url, 'anthropic'), + http_client=http_client, + ) + ) elif upstream_provider == 'google-vertex': from google.genai import Client as GoogleClient @@ -140,6 +162,10 @@ def infer_model(model_name: str) -> Model: from pydantic_ai.models.groq import GroqModel return GroqModel(model_name, provider=gateway_provider('groq')) + elif upstream_provider == 'anthropic': + from pydantic_ai.models.anthropic import AnthropicModel + + return AnthropicModel(model_name, provider=gateway_provider('anthropic')) elif upstream_provider == 'google-vertex': from pydantic_ai.models.google import GoogleModel diff --git a/tests/providers/cassettes/test_gateway/test_gateway_provider_with_anthropic.yaml b/tests/providers/cassettes/test_gateway/test_gateway_provider_with_anthropic.yaml new file mode 100644 index 0000000000..d8383164d8 --- /dev/null +++ b/tests/providers/cassettes/test_gateway/test_gateway_provider_with_anthropic.yaml @@ -0,0 +1,63 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '166' + content-type: + - application/json + host: + - localhost:8787 + method: POST + parsed_body: + max_tokens: 4096 + messages: + - content: + - text: What is the capital of France? + type: text + role: user + model: claude-3-5-sonnet-latest + stream: false + uri: http://localhost:8787/anthropic/v1/messages?beta=true + response: + headers: + content-length: + - '500' + content-type: + - application/json + pydantic-ai-gateway-price-estimate: + - 0.0002USD + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: The capital of France is Paris. + type: text + id: msg_015tco2dv5oh9rFq1PcZAduv + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: end_turn + stop_sequence: null + type: message + usage: + cache_creation: + ephemeral_1h_input_tokens: 0 + ephemeral_5m_input_tokens: 0 + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 14 + output_tokens: 10 + pydantic_ai_gateway: + cost_estimate: 0.00019199999999999998 + service_tier: standard + status: + code: 200 + message: OK +version: 1 diff --git a/tests/providers/test_gateway.py b/tests/providers/test_gateway.py index 2213aead7e..1cf4299b6f 100644 --- a/tests/providers/test_gateway.py +++ b/tests/providers/test_gateway.py @@ -13,6 +13,7 @@ from ..conftest import TestEnv, try_import with try_import() as imports_successful: + from pydantic_ai.models.anthropic import AnthropicModel from pydantic_ai.models.google import GoogleModel from pydantic_ai.models.groq import GroqModel from pydantic_ai.models.openai import OpenAIChatModel, OpenAIResponsesModel @@ -94,6 +95,11 @@ def test_infer_model(): assert model.model_name == 'gemini-1.5-flash' assert model.system == 'google-vertex' + model = infer_model('anthropic/claude-3-5-sonnet-latest') + assert isinstance(model, AnthropicModel) + assert model.model_name == 'claude-3-5-sonnet-latest' + assert model.system == 'anthropic' + with raises(snapshot('UserError: The model name "gemini-1.5-flash" is not in the format "provider/model_name".')): infer_model('gemini-1.5-flash') @@ -135,3 +141,12 @@ async def test_gateway_provider_with_google_vertex(allow_model_requests: None, g result = await agent.run('What is the capital of France?') assert result.output == snapshot('Paris\n') + + +async def test_gateway_provider_with_anthropic(allow_model_requests: None, gateway_api_key: str): + provider = gateway_provider('anthropic', api_key=gateway_api_key, base_url='http://localhost:8787') + model = AnthropicModel('claude-3-5-sonnet-latest', provider=provider) + agent = Agent(model) + + result = await agent.run('What is the capital of France?') + assert result.output == snapshot('The capital of France is Paris.')