Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pydantic_ai_slim/pydantic_ai/providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,15 @@ def model_profile(self, model_name: str) -> ModelProfile | None:
def __init__(self, *, anthropic_client: AsyncAnthropicClient | None = None) -> None: ...

@overload
def __init__(self, *, api_key: str | None = None, http_client: httpx.AsyncClient | None = None) -> None: ...
# Overload: construct the provider from credentials/connection settings rather than
# an existing AsyncAnthropic client. `base_url` lets callers point the SDK at a
# proxy/gateway endpoint instead of the default Anthropic API URL.
def __init__(
    self, *, api_key: str | None = None, base_url: str | None = None, http_client: httpx.AsyncClient | None = None
) -> None: ...

def __init__(
self,
*,
api_key: str | None = None,
base_url: str | None = None,
anthropic_client: AsyncAnthropicClient | None = None,
http_client: httpx.AsyncClient | None = None,
) -> None:
Expand All @@ -59,6 +62,7 @@ def __init__(
Args:
api_key: The API key to use for authentication, if not provided, the `ANTHROPIC_API_KEY` environment variable
will be used if available.
base_url: The base URL to use for the Anthropic API.
anthropic_client: An existing [`AsyncAnthropic`](https://github.com/anthropics/anthropic-sdk-python)
client to use. If provided, the `api_key` and `http_client` arguments will be ignored.
http_client: An existing `httpx.AsyncClient` to use for making HTTP requests.
Expand All @@ -75,7 +79,7 @@ def __init__(
'to use the Anthropic provider.'
)
if http_client is not None:
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
self._client = AsyncAnthropic(api_key=api_key, base_url=base_url, http_client=http_client)
else:
http_client = cached_async_http_client(provider='anthropic')
self._client = AsyncAnthropic(api_key=api_key, http_client=http_client)
self._client = AsyncAnthropic(api_key=api_key, base_url=base_url, http_client=http_client)
28 changes: 27 additions & 1 deletion pydantic_ai_slim/pydantic_ai/providers/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from groq import AsyncGroq
from openai import AsyncOpenAI

from pydantic_ai.models.anthropic import AsyncAnthropicClient
from pydantic_ai.providers import Provider


Expand Down Expand Up @@ -48,8 +49,17 @@ def gateway_provider(
) -> Provider[GoogleClient]: ...


@overload
def gateway_provider(
    upstream_provider: Literal['anthropic'],
    *,
    api_key: str | None = None,
    base_url: str | None = None,
) -> Provider[AsyncAnthropicClient]: ...
    # Overload: selecting the 'anthropic' upstream yields a Provider wrapping an
    # AsyncAnthropicClient, routed through the gateway's /anthropic path.
    # NOTE(review): unlike the other overloads visible in this diff, no
    # `http_client` parameter is declared here — presumably intentional, but
    # worth confirming against the sibling overloads.


def gateway_provider(
upstream_provider: Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex'] | str,
upstream_provider: Literal['openai', 'openai-chat', 'openai-responses', 'groq', 'google-vertex', 'anthropic'] | str,
*,
# Every provider
api_key: str | None = None,
Expand Down Expand Up @@ -90,6 +100,18 @@ def gateway_provider(
from .groq import GroqProvider

return GroqProvider(api_key=api_key, base_url=urljoin(base_url, 'groq'), http_client=http_client)
elif upstream_provider == 'anthropic':
from anthropic import AsyncAnthropic

from .anthropic import AnthropicProvider

return AnthropicProvider(
anthropic_client=AsyncAnthropic(
auth_token=api_key,
base_url=urljoin(base_url, 'anthropic'),
http_client=http_client,
)
)
elif upstream_provider == 'google-vertex':
from google.genai import Client as GoogleClient

Expand Down Expand Up @@ -140,6 +162,10 @@ def infer_model(model_name: str) -> Model:
from pydantic_ai.models.groq import GroqModel

return GroqModel(model_name, provider=gateway_provider('groq'))
elif upstream_provider == 'anthropic':
from pydantic_ai.models.anthropic import AnthropicModel

return AnthropicModel(model_name, provider=gateway_provider('anthropic'))
elif upstream_provider == 'google-vertex':
from pydantic_ai.models.google import GoogleModel

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
interactions:
- request:
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '166'
content-type:
- application/json
host:
- localhost:8787
method: POST
parsed_body:
max_tokens: 4096
messages:
- content:
- text: What is the capital of France?
type: text
role: user
model: claude-3-5-sonnet-latest
stream: false
uri: http://localhost:8787/anthropic/v1/messages?beta=true
response:
headers:
content-length:
- '500'
content-type:
- application/json
pydantic-ai-gateway-price-estimate:
- 0.0002USD
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
transfer-encoding:
- chunked
parsed_body:
content:
- text: The capital of France is Paris.
type: text
id: msg_015tco2dv5oh9rFq1PcZAduv
model: claude-3-5-sonnet-20241022
role: assistant
stop_reason: end_turn
stop_sequence: null
type: message
usage:
cache_creation:
ephemeral_1h_input_tokens: 0
ephemeral_5m_input_tokens: 0
cache_creation_input_tokens: 0
cache_read_input_tokens: 0
input_tokens: 14
output_tokens: 10
pydantic_ai_gateway:
cost_estimate: 0.00019199999999999998
service_tier: standard
status:
code: 200
message: OK
version: 1
15 changes: 15 additions & 0 deletions tests/providers/test_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ..conftest import TestEnv, try_import

with try_import() as imports_successful:
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.models.groq import GroqModel
from pydantic_ai.models.openai import OpenAIChatModel, OpenAIResponsesModel
Expand Down Expand Up @@ -94,6 +95,11 @@ def test_infer_model():
assert model.model_name == 'gemini-1.5-flash'
assert model.system == 'google-vertex'

model = infer_model('anthropic/claude-3-5-sonnet-latest')
assert isinstance(model, AnthropicModel)
assert model.model_name == 'claude-3-5-sonnet-latest'
assert model.system == 'anthropic'

with raises(snapshot('UserError: The model name "gemini-1.5-flash" is not in the format "provider/model_name".')):
infer_model('gemini-1.5-flash')

Expand Down Expand Up @@ -135,3 +141,12 @@ async def test_gateway_provider_with_google_vertex(allow_model_requests: None, g

result = await agent.run('What is the capital of France?')
assert result.output == snapshot('Paris\n')


async def test_gateway_provider_with_anthropic(allow_model_requests: None, gateway_api_key: str):
    """End-to-end check that the gateway provider routes Anthropic requests.

    Replays the recorded cassette (POST http://localhost:8787/anthropic/v1/messages)
    and asserts the model's answer matches the recorded completion. The
    `allow_model_requests` fixture gates outbound model calls; `gateway_api_key`
    supplies the gateway credential.
    """
    provider = gateway_provider('anthropic', api_key=gateway_api_key, base_url='http://localhost:8787')
    model = AnthropicModel('claude-3-5-sonnet-latest', provider=provider)
    agent = Agent(model)

    result = await agent.run('What is the capital of France?')
    # Expected output is pinned to the recorded cassette response.
    assert result.output == snapshot('The capital of France is Paris.')