Skip to content

Commit ab6ee23

Browse files
committed
fixing remaining tests, breaking TestModel
1 parent d85373a commit ab6ee23

File tree

5 files changed

+97
-120
lines changed

5 files changed

+97
-120
lines changed

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@ def is_left(self) -> bool:
126126
def whichever(self) -> Left | Right:
127127
return self._left.value if self._left is not None else self.right
128128

129+
def __repr__(self):
130+
if left := self._left:
131+
return f'Either(left={left.value!r})'
132+
else:
133+
return f'Either(right={self.right!r})'
134+
129135

130136
@asynccontextmanager
131137
async def group_by_temporal(

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class TestModel(Model):
5959
agent_model_function_tools: dict[str, ToolDefinition] | None = field(default=None, init=False)
6060
agent_model_allow_text_result: bool | None = field(default=None, init=False)
6161
agent_model_result_tools: dict[str, ToolDefinition] | None = field(default=None, init=False)
62-
_agent_model: TestAgentModel | None = field(default=None, init=False)
62+
agent_model: TestAgentModel | None = field(default=None, init=False)
6363

6464
async def prepare(
6565
self,
@@ -71,40 +71,43 @@ async def prepare(
7171
self.agent_model_function_tools = function_tools
7272
self.agent_model_allow_text_result = allow_text_result
7373
self.agent_model_result_tools = result_tools
74-
if self._agent_model is None:
75-
self._agent_model = self._build_agent()
76-
return self._agent_model
74+
if self.agent_model is None:
75+
self.agent_model = self._build_agent(function_tools, allow_text_result, result_tools)
76+
return self.agent_model
7777

78-
def _build_agent(self):
78+
def _build_agent(
79+
self,
80+
function_tools: dict[str, ToolDefinition],
81+
allow_text_result: bool,
82+
result_tools: dict[str, ToolDefinition] | None,
83+
) -> TestAgentModel:
7984
if self.call_tools == 'all':
80-
tool_calls = [(r.name, r) for r in self.agent_model_function_tools.values()]
85+
tool_calls = [(r.name, r) for r in function_tools.values()]
8186
else:
82-
tools_to_call = (self.agent_model_function_tools[name] for name in self.call_tools)
87+
tools_to_call = (function_tools[name] for name in self.call_tools)
8388
tool_calls = [(r.name, r) for r in tools_to_call]
8489

8590
if self.custom_result_text is not None:
86-
assert self.agent_model_allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
91+
assert allow_text_result, 'Plain response not allowed, but `custom_result_text` is set.'
8792
assert self.custom_result_args is None, 'Cannot set both `custom_result_text` and `custom_result_args`.'
8893
result: _utils.Either[str | None, Any | None] = _utils.Either(left=self.custom_result_text)
8994
elif self.custom_result_args is not None:
90-
assert (
91-
self.agent_model_result_tools is not None
92-
), 'No result tools provided, but `custom_result_args` is set.'
93-
result_tool = next(iter(self.agent_model_result_tools.values()))
95+
assert result_tools is not None, 'No result tools provided, but `custom_result_args` is set.'
96+
result_tool = next(iter(result_tools.values()))
9497

9598
if k := result_tool.outer_typed_dict_key:
9699
result = _utils.Either(right={k: self.custom_result_args})
97100
else:
98101
result = _utils.Either(right=self.custom_result_args)
99-
elif self.agent_model_allow_text_result:
102+
elif allow_text_result:
100103
result = _utils.Either(left=None)
101-
elif self.agent_model_result_tools is not None:
104+
elif result_tools is not None:
102105
result = _utils.Either(right=None)
103106
else:
104107
result = _utils.Either(left=None)
105108

106-
result_tools_list = list(self.agent_model_function_tools.values()) if self.agent_model_function_tools else None
107-
return TestAgentModel(tool_calls, result, result_tools_list, self.seed)
109+
result_tools = list(result_tools.values()) if result_tools is not None else None
110+
return TestAgentModel(tool_calls, result, result_tools, self.seed)
108111

109112
def name(self) -> str:
110113
return 'test-model'

tests/test_agent.py

Lines changed: 57 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from pydantic_ai.models.function import AgentInfo, FunctionModel
2525
from pydantic_ai.models.test import TestModel
2626
from pydantic_ai.result import Cost, RunResult
27+
from pydantic_ai.tools import ToolDefinition
2728

2829
from .conftest import IsNow, TestEnv
2930

@@ -34,7 +35,7 @@ def test_result_tuple(set_event_loop: None):
3435
def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
3536
assert info.result_tools is not None
3637
args_json = '{"response": ["foo", "bar"]}'
37-
return ModelStructuredResponse(calls=[ToolCall.from_json(info.result_tools[0].name, args_json)])
38+
return ModelStructuredResponse(calls=[ToolCall.from_json(info.result_tools['final_result'].name, args_json)])
3839

3940
agent = Agent(FunctionModel(return_tuple), result_type=tuple[str, str])
4041

@@ -51,7 +52,7 @@ def test_result_pydantic_model(set_event_loop: None):
5152
def return_model(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
5253
assert info.result_tools is not None
5354
args_json = '{"a": 1, "b": "foo"}'
54-
return ModelStructuredResponse(calls=[ToolCall.from_json(info.result_tools[0].name, args_json)])
55+
return ModelStructuredResponse(calls=[ToolCall.from_json(info.result_tools['final_result'].name, args_json)])
5556

5657
agent = Agent(FunctionModel(return_model), result_type=Foo)
5758

@@ -67,7 +68,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
6768
args_json = '{"a": "wrong", "b": "foo"}'
6869
else:
6970
args_json = '{"a": 42, "b": "foo"}'
70-
return ModelStructuredResponse(calls=[ToolCall.from_json(info.result_tools[0].name, args_json)])
71+
return ModelStructuredResponse(calls=[ToolCall.from_json(info.result_tools['final_result'].name, args_json)])
7172

7273
agent = Agent(FunctionModel(return_model), result_type=Foo)
7374

@@ -112,7 +113,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
112113
args_json = '{"a": 41, "b": "foo"}'
113114
else:
114115
args_json = '{"a": 42, "b": "foo"}'
115-
return ModelStructuredResponse(calls=[ToolCall.from_json(info.result_tools[0].name, args_json)])
116+
return ModelStructuredResponse(calls=[ToolCall.from_json(info.result_tools['final_result'].name, args_json)])
116117

117118
agent = Agent(FunctionModel(return_model), result_type=Foo)
118119

@@ -153,7 +154,9 @@ def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
153154
return ModelTextResponse(content='hello')
154155
else:
155156
args_json = '{"response": ["foo", "bar"]}'
156-
return ModelStructuredResponse(calls=[ToolCall.from_json(info.result_tools[0].name, args_json)])
157+
return ModelStructuredResponse(
158+
calls=[ToolCall.from_json(info.result_tools['final_result'].name, args_json)]
159+
)
157160

158161
agent = Agent(FunctionModel(return_tuple), result_type=tuple[str, str])
159162

@@ -191,27 +194,26 @@ def test_response_tuple(set_event_loop: None):
191194
assert m.agent_model_result_tools is not None
192195
assert len(m.agent_model_result_tools) == 1
193196

194-
# to match the protocol, we just extract the attributes we care about
195-
fields = 'name', 'description', 'parameters_json_schema', 'outer_typed_dict_key'
196-
agent_model_result_tool = {f: getattr(m.agent_model_result_tools[0], f) for f in fields}
197-
assert agent_model_result_tool == snapshot(
197+
assert m.agent_model_result_tools == snapshot(
198198
{
199-
'name': 'final_result',
200-
'description': 'The final response which ends this conversation',
201-
'parameters_json_schema': {
202-
'properties': {
203-
'response': {
204-
'maxItems': 2,
205-
'minItems': 2,
206-
'prefixItems': [{'type': 'string'}, {'type': 'string'}],
207-
'title': 'Response',
208-
'type': 'array',
209-
}
199+
'final_result': ToolDefinition(
200+
name='final_result',
201+
description='The final response which ends this conversation',
202+
parameters_json_schema={
203+
'properties': {
204+
'response': {
205+
'maxItems': 2,
206+
'minItems': 2,
207+
'prefixItems': [{'type': 'string'}, {'type': 'string'}],
208+
'title': 'Response',
209+
'type': 'array',
210+
}
211+
},
212+
'required': ['response'],
213+
'type': 'object',
210214
},
211-
'required': ['response'],
212-
'type': 'object',
213-
},
214-
'outer_typed_dict_key': 'response',
215+
outer_typed_dict_key='response',
216+
)
215217
}
216218
)
217219

@@ -250,20 +252,21 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
250252
assert m.agent_model_result_tools is not None
251253
assert len(m.agent_model_result_tools) == 1
252254

253-
# to match the protocol, we just extract the attributes we care about
254-
fields = 'name', 'description', 'parameters_json_schema', 'outer_typed_dict_key'
255-
agent_model_result_tool = {f: getattr(m.agent_model_result_tools[0], f) for f in fields}
256-
assert agent_model_result_tool == snapshot(
255+
assert m.agent_model_result_tools == snapshot(
257256
{
258-
'name': 'final_result',
259-
'description': 'The final response which ends this conversation',
260-
'parameters_json_schema': {
261-
'properties': {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'string'}},
262-
'title': 'Foo',
263-
'required': ['a', 'b'],
264-
'type': 'object',
265-
},
266-
'outer_typed_dict_key': None,
257+
'final_result': ToolDefinition(
258+
name='final_result',
259+
description='The final response which ends this conversation',
260+
parameters_json_schema={
261+
'properties': {
262+
'a': {'title': 'A', 'type': 'integer'},
263+
'b': {'title': 'B', 'type': 'string'},
264+
},
265+
'required': ['a', 'b'],
266+
'title': 'Foo',
267+
'type': 'object',
268+
},
269+
)
267270
}
268271
)
269272

@@ -324,15 +327,12 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
324327
assert m.agent_model_result_tools is not None
325328
assert len(m.agent_model_result_tools) == 2
326329

327-
# to match the protocol, we just extract the attributes we care about
328-
fields = 'name', 'description', 'parameters_json_schema', 'outer_typed_dict_key'
329-
agent_model_result_tools = [{f: getattr(t, f) for f in fields} for t in m.agent_model_result_tools]
330-
assert agent_model_result_tools == snapshot(
331-
[
332-
{
333-
'name': 'final_result_Foo',
334-
'description': 'Foo: The final response which ends this conversation',
335-
'parameters_json_schema': {
330+
assert m.agent_model_result_tools == snapshot(
331+
{
332+
'final_result_Foo': ToolDefinition(
333+
name='final_result_Foo',
334+
description='Foo: The final response which ends this conversation',
335+
parameters_json_schema={
336336
'properties': {
337337
'a': {'title': 'A', 'type': 'integer'},
338338
'b': {'title': 'B', 'type': 'string'},
@@ -341,20 +341,18 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
341341
'title': 'Foo',
342342
'type': 'object',
343343
},
344-
'outer_typed_dict_key': None,
345-
},
346-
{
347-
'name': 'final_result_Bar',
348-
'description': 'This is a bar model.',
349-
'parameters_json_schema': {
344+
),
345+
'final_result_Bar': ToolDefinition(
346+
name='final_result_Bar',
347+
description='This is a bar model.',
348+
parameters_json_schema={
350349
'properties': {'b': {'title': 'B', 'type': 'string'}},
351350
'required': ['b'],
352351
'title': 'Bar',
353352
'type': 'object',
354353
},
355-
'outer_typed_dict_key': None,
356-
},
357-
]
354+
),
355+
}
358356
)
359357

360358
result = agent.run_sync('Hello', model=TestModel(seed=1))
@@ -383,6 +381,7 @@ async def ret_a(x: str) -> str:
383381
ModelTextResponse(content='{"ret_a":"a-apple"}', timestamp=IsNow(tz=timezone.utc)),
384382
]
385383
)
384+
m.agent_model = None
386385

387386
# if we pass new_messages, system prompt is inserted before the message_history messages
388387
result2 = agent.run_sync('Hello again', message_history=result1.new_messages())
@@ -415,6 +414,8 @@ async def ret_a(x: str) -> str:
415414
assert new_msg_roles == snapshot(['user', 'model-structured-response', 'tool-return', 'model-text-response'])
416415
assert result2.new_messages_json().startswith(b'[{"content":"Hello again",')
417416

417+
m.agent_model = None
418+
418419
# if we pass all_messages, system prompt is NOT inserted before the message_history messages,
419420
# so only one system prompt
420421
result3 = agent.run_sync('Hello again', message_history=result1.all_messages())
@@ -540,6 +541,7 @@ async def make_request() -> str:
540541
for _ in range(2):
541542
result = agent.run_sync('Hello')
542543
assert result.data == '{"make_request":"200"}'
544+
agent.model.agent_model = None
543545

544546

545547
async def test_agent_name():

tests/test_deps.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,23 @@ async def example_tool(ctx: RunContext[MyDeps]) -> str:
2121
def test_deps_used(set_event_loop: None):
2222
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
2323
assert result.data == '{"example_tool":"MyDeps(foo=1, bar=2)"}'
24+
agent.model.agent_model = None
2425

2526

2627
def test_deps_override(set_event_loop: None):
2728
with agent.override(deps=MyDeps(foo=3, bar=4)):
2829
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
2930
assert result.data == '{"example_tool":"MyDeps(foo=3, bar=4)"}'
31+
agent.model.agent_model = None
3032

3133
with agent.override(deps=MyDeps(foo=5, bar=6)):
3234
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
3335
assert result.data == '{"example_tool":"MyDeps(foo=5, bar=6)"}'
36+
agent.model.agent_model = None
3437

3538
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
3639
assert result.data == '{"example_tool":"MyDeps(foo=3, bar=4)"}'
40+
agent.model.agent_model = None
3741

3842
result = agent.run_sync('foobar', deps=MyDeps(foo=1, bar=2))
3943
assert result.data == '{"example_tool":"MyDeps(foo=1, bar=2)"}'

0 commit comments

Comments
 (0)