Skip to content

Commit cdf7939

Browse files
committed
revert agent_model() name change
1 parent ab6ee23 commit cdf7939

File tree

12 files changed

+34
-32
lines changed

12 files changed

+34
-32
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ async def add_tool(key: str, tool: Tool[AgentDeps]) -> None:
669669

670670
await asyncio.gather(*(add_tool(k, t) for k, t in self._function_tools.items()))
671671

672-
return await model.prepare(
672+
return await model.agent_model(
673673
function_tools=function_tools,
674674
allow_text_result=self._allow_text_result,
675675
result_tools=self._result_schema.tool_defs() if self._result_schema is not None else None,

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class Model(ABC):
6161
"""Abstract class for a model."""
6262

6363
@abstractmethod
64-
async def prepare(
64+
async def agent_model(
6565
self,
6666
*,
6767
function_tools: dict[str, ToolDefinition],

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(self, function: FunctionDef | None = None, *, stream_function: Stre
5151
self.function = function
5252
self.stream_function = stream_function
5353

54-
async def prepare(
54+
async def agent_model(
5555
self,
5656
*,
5757
function_tools: dict[str, ToolDefinition],

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def __init__(
8888
self.http_client = http_client or cached_async_http_client()
8989
self.url = url_template.format(model=model_name)
9090

91-
async def prepare(
91+
async def agent_model(
9292
self,
9393
*,
9494
function_tools: dict[str, ToolDefinition],

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __init__(
107107
else:
108108
self.client = AsyncGroq(api_key=api_key, http_client=cached_async_http_client())
109109

110-
async def prepare(
110+
async def agent_model(
111111
self,
112112
*,
113113
function_tools: dict[str, ToolDefinition],

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(
8787
else:
8888
self.client = AsyncOpenAI(api_key=api_key, http_client=cached_async_http_client())
8989

90-
async def prepare(
90+
async def agent_model(
9191
self,
9292
*,
9393
function_tools: dict[str, ToolDefinition],

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ class TestModel(Model):
5959
agent_model_function_tools: dict[str, ToolDefinition] | None = field(default=None, init=False)
6060
agent_model_allow_text_result: bool | None = field(default=None, init=False)
6161
agent_model_result_tools: dict[str, ToolDefinition] | None = field(default=None, init=False)
62-
agent_model: TestAgentModel | None = field(default=None, init=False)
62+
agent_model_cache: TestAgentModel | None = field(default=None, init=False)
6363

64-
async def prepare(
64+
async def agent_model(
6565
self,
6666
*,
6767
function_tools: dict[str, ToolDefinition],
@@ -71,9 +71,9 @@ async def prepare(
7171
self.agent_model_function_tools = function_tools
7272
self.agent_model_allow_text_result = allow_text_result
7373
self.agent_model_result_tools = result_tools
74-
if self.agent_model is None:
75-
self.agent_model = self._build_agent(function_tools, allow_text_result, result_tools)
76-
return self.agent_model
74+
if self.agent_model_cache is None:
75+
self.agent_model_cache = self._build_agent(function_tools, allow_text_result, result_tools)
76+
return self.agent_model_cache
7777

7878
def _build_agent(
7979
self,

pydantic_ai_slim/pydantic_ai/models/vertexai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __init__(
107107
self.auth = None
108108
self.url = None
109109

110-
async def prepare(
110+
async def agent_model(
111111
self,
112112
*,
113113
function_tools: dict[str, ToolDefinition],

tests/models/test_gemini.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_api_key_empty(env: TestEnv):
7777

7878
async def test_agent_model_simple(allow_model_requests: None):
7979
m = GeminiModel('gemini-1.5-flash', api_key='via-arg')
80-
agent_model = await m.prepare(function_tools={}, allow_text_result=True, result_tools=None)
80+
agent_model = await m.agent_model(function_tools={}, allow_text_result=True, result_tools=None)
8181
assert isinstance(agent_model.http_client, httpx.AsyncClient)
8282
assert agent_model.model_name == 'gemini-1.5-flash'
8383
assert isinstance(agent_model.auth, ApiKeyAuth)
@@ -110,7 +110,9 @@ async def test_agent_model_tools(allow_model_requests: None):
110110
'This is the tool for the final Result',
111111
{'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}, 'required': ['spam']},
112112
)
113-
agent_model = await m.prepare(function_tools=tools, allow_text_result=True, result_tools={'result': result_tool})
113+
agent_model = await m.agent_model(
114+
function_tools=tools, allow_text_result=True, result_tools={'result': result_tool}
115+
)
114116
assert agent_model.tools == snapshot(
115117
_GeminiTools(
116118
function_declarations=[
@@ -149,7 +151,7 @@ async def test_require_response_tool(allow_model_requests: None):
149151
'This is the tool for the final Result',
150152
{'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}},
151153
)
152-
agent_model = await m.prepare(function_tools={}, allow_text_result=False, result_tools={'result': result_tool})
154+
agent_model = await m.agent_model(function_tools={}, allow_text_result=False, result_tools={'result': result_tool})
153155
assert agent_model.tools == snapshot(
154156
_GeminiTools(
155157
function_declarations=[
@@ -206,7 +208,7 @@ class Locations(BaseModel):
206208
'This is the tool for the final Result',
207209
json_schema,
208210
)
209-
agent_model = await m.prepare(function_tools={}, allow_text_result=True, result_tools={'result': result_tool})
211+
agent_model = await m.agent_model(function_tools={}, allow_text_result=True, result_tools={'result': result_tool})
210212
assert agent_model.tools == snapshot(
211213
_GeminiTools(
212214
function_declarations=[
@@ -252,7 +254,7 @@ class Locations(BaseModel):
252254
'This is the tool for the final Result',
253255
json_schema,
254256
)
255-
agent_model = await m.prepare(function_tools={}, allow_text_result=True, result_tools={'result': result_tool})
257+
agent_model = await m.agent_model(function_tools={}, allow_text_result=True, result_tools={'result': result_tool})
256258
assert agent_model.tools == snapshot(
257259
_GeminiTools(
258260
function_declarations=[
@@ -316,7 +318,7 @@ class Location(BaseModel):
316318
json_schema,
317319
)
318320
with pytest.raises(UserError, match=r'Recursive `\$ref`s in JSON Schema are not supported by Gemini'):
319-
await m.prepare(function_tools={}, allow_text_result=True, result_tools={'result': result_tool})
321+
await m.agent_model(function_tools={}, allow_text_result=True, result_tools={'result': result_tool})
320322

321323

322324
@dataclass

tests/models/test_vertexai.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def test_init_service_account(tmp_path: Path, allow_model_requests: None):
3333
assert model.url is None
3434
assert model.auth is None
3535

36-
await model.prepare(function_tools={}, allow_text_result=True, result_tools=None)
36+
await model.agent_model(function_tools={}, allow_text_result=True, result_tools=None)
3737

3838
assert model.url == snapshot(
3939
'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
@@ -58,7 +58,7 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
5858

5959
assert patch.call_count == 0
6060

61-
await model.prepare(function_tools={}, allow_text_result=True, result_tools=None)
61+
await model.agent_model(function_tools={}, allow_text_result=True, result_tools=None)
6262

6363
assert patch.call_count == 1
6464

@@ -69,7 +69,7 @@ async def test_init_env(mocker: MockerFixture, allow_model_requests: None):
6969
assert model.auth is not None
7070
assert model.name() == snapshot('vertexai:gemini-1.5-flash')
7171

72-
await model.prepare(function_tools={}, allow_text_result=True, result_tools=None)
72+
await model.agent_model(function_tools={}, allow_text_result=True, result_tools=None)
7373
assert model.url is not None
7474
assert model.auth is not None
7575
assert patch.call_count == 1
@@ -83,7 +83,7 @@ async def test_init_right_project_id(tmp_path: Path, allow_model_requests: None)
8383
assert model.url is None
8484
assert model.auth is None
8585

86-
await model.prepare(function_tools={}, allow_text_result=True, result_tools=None)
86+
await model.agent_model(function_tools={}, allow_text_result=True, result_tools=None)
8787

8888
assert model.url == snapshot(
8989
'https://us-central1-aiplatform.googleapis.com/v1/projects/my-project-id/locations/us-central1/'
@@ -99,7 +99,7 @@ async def test_init_service_account_wrong_project_id(tmp_path: Path):
9999
model = VertexAIModel('gemini-1.5-flash', service_account_file=service_account_path, project_id='different')
100100

101101
with pytest.raises(UserError) as exc_info:
102-
await model.prepare(function_tools={}, allow_text_result=True, result_tools=None)
102+
await model.agent_model(function_tools={}, allow_text_result=True, result_tools=None)
103103
assert str(exc_info.value) == snapshot(
104104
"The project_id you provided does not match the one from service account file: 'different' != 'my-project-id'"
105105
)
@@ -110,7 +110,7 @@ async def test_init_env_wrong_project_id(mocker: MockerFixture):
110110
model = VertexAIModel('gemini-1.5-flash', project_id='different')
111111

112112
with pytest.raises(UserError) as exc_info:
113-
await model.prepare(function_tools={}, allow_text_result=True, result_tools=None)
113+
await model.agent_model(function_tools={}, allow_text_result=True, result_tools=None)
114114
assert str(exc_info.value) == snapshot(
115115
"The project_id you provided does not match the one from `google.auth.default()`: 'different' != 'my-project-id'"
116116
)
@@ -124,7 +124,7 @@ async def test_init_env_no_project_id(mocker: MockerFixture):
124124
model = VertexAIModel('gemini-1.5-flash')
125125

126126
with pytest.raises(UserError) as exc_info:
127-
await model.prepare(function_tools={}, allow_text_result=True, result_tools=None)
127+
await model.agent_model(function_tools={}, allow_text_result=True, result_tools=None)
128128
assert str(exc_info.value) == snapshot('No project_id provided and none found in `google.auth.default()`')
129129

130130

0 commit comments

Comments
 (0)