24
24
from pydantic_ai .models .function import AgentInfo , FunctionModel
25
25
from pydantic_ai .models .test import TestModel
26
26
from pydantic_ai .result import Cost , RunResult
27
+ from pydantic_ai .tools import ToolDefinition
27
28
28
29
from .conftest import IsNow , TestEnv
29
30
@@ -34,7 +35,7 @@ def test_result_tuple(set_event_loop: None):
34
35
def return_tuple (_ : list [Message ], info : AgentInfo ) -> ModelAnyResponse :
35
36
assert info .result_tools is not None
36
37
args_json = '{"response": ["foo", "bar"]}'
37
- return ModelStructuredResponse (calls = [ToolCall .from_json (info .result_tools [0 ].name , args_json )])
38
+ return ModelStructuredResponse (calls = [ToolCall .from_json (info .result_tools ['final_result' ].name , args_json )])
38
39
39
40
agent = Agent (FunctionModel (return_tuple ), result_type = tuple [str , str ])
40
41
@@ -51,7 +52,7 @@ def test_result_pydantic_model(set_event_loop: None):
51
52
def return_model (_ : list [Message ], info : AgentInfo ) -> ModelAnyResponse :
52
53
assert info .result_tools is not None
53
54
args_json = '{"a": 1, "b": "foo"}'
54
- return ModelStructuredResponse (calls = [ToolCall .from_json (info .result_tools [0 ].name , args_json )])
55
+ return ModelStructuredResponse (calls = [ToolCall .from_json (info .result_tools ['final_result' ].name , args_json )])
55
56
56
57
agent = Agent (FunctionModel (return_model ), result_type = Foo )
57
58
@@ -67,7 +68,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
67
68
args_json = '{"a": "wrong", "b": "foo"}'
68
69
else :
69
70
args_json = '{"a": 42, "b": "foo"}'
70
- return ModelStructuredResponse (calls = [ToolCall .from_json (info .result_tools [0 ].name , args_json )])
71
+ return ModelStructuredResponse (calls = [ToolCall .from_json (info .result_tools ['final_result' ].name , args_json )])
71
72
72
73
agent = Agent (FunctionModel (return_model ), result_type = Foo )
73
74
@@ -112,7 +113,7 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
112
113
args_json = '{"a": 41, "b": "foo"}'
113
114
else :
114
115
args_json = '{"a": 42, "b": "foo"}'
115
- return ModelStructuredResponse (calls = [ToolCall .from_json (info .result_tools [0 ].name , args_json )])
116
+ return ModelStructuredResponse (calls = [ToolCall .from_json (info .result_tools ['final_result' ].name , args_json )])
116
117
117
118
agent = Agent (FunctionModel (return_model ), result_type = Foo )
118
119
@@ -153,7 +154,9 @@ def return_tuple(_: list[Message], info: AgentInfo) -> ModelAnyResponse:
153
154
return ModelTextResponse (content = 'hello' )
154
155
else :
155
156
args_json = '{"response": ["foo", "bar"]}'
156
- return ModelStructuredResponse (calls = [ToolCall .from_json (info .result_tools [0 ].name , args_json )])
157
+ return ModelStructuredResponse (
158
+ calls = [ToolCall .from_json (info .result_tools ['final_result' ].name , args_json )]
159
+ )
157
160
158
161
agent = Agent (FunctionModel (return_tuple ), result_type = tuple [str , str ])
159
162
@@ -191,27 +194,26 @@ def test_response_tuple(set_event_loop: None):
191
194
assert m .agent_model_result_tools is not None
192
195
assert len (m .agent_model_result_tools ) == 1
193
196
194
- # to match the protocol, we just extract the attributes we care about
195
- fields = 'name' , 'description' , 'parameters_json_schema' , 'outer_typed_dict_key'
196
- agent_model_result_tool = {f : getattr (m .agent_model_result_tools [0 ], f ) for f in fields }
197
- assert agent_model_result_tool == snapshot (
197
+ assert m .agent_model_result_tools == snapshot (
198
198
{
199
- 'name' : 'final_result' ,
200
- 'description' : 'The final response which ends this conversation' ,
201
- 'parameters_json_schema' : {
202
- 'properties' : {
203
- 'response' : {
204
- 'maxItems' : 2 ,
205
- 'minItems' : 2 ,
206
- 'prefixItems' : [{'type' : 'string' }, {'type' : 'string' }],
207
- 'title' : 'Response' ,
208
- 'type' : 'array' ,
209
- }
199
+ 'final_result' : ToolDefinition (
200
+ name = 'final_result' ,
201
+ description = 'The final response which ends this conversation' ,
202
+ parameters_json_schema = {
203
+ 'properties' : {
204
+ 'response' : {
205
+ 'maxItems' : 2 ,
206
+ 'minItems' : 2 ,
207
+ 'prefixItems' : [{'type' : 'string' }, {'type' : 'string' }],
208
+ 'title' : 'Response' ,
209
+ 'type' : 'array' ,
210
+ }
211
+ },
212
+ 'required' : ['response' ],
213
+ 'type' : 'object' ,
210
214
},
211
- 'required' : ['response' ],
212
- 'type' : 'object' ,
213
- },
214
- 'outer_typed_dict_key' : 'response' ,
215
+ outer_typed_dict_key = 'response' ,
216
+ )
215
217
}
216
218
)
217
219
@@ -250,20 +252,21 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
250
252
assert m .agent_model_result_tools is not None
251
253
assert len (m .agent_model_result_tools ) == 1
252
254
253
- # to match the protocol, we just extract the attributes we care about
254
- fields = 'name' , 'description' , 'parameters_json_schema' , 'outer_typed_dict_key'
255
- agent_model_result_tool = {f : getattr (m .agent_model_result_tools [0 ], f ) for f in fields }
256
- assert agent_model_result_tool == snapshot (
255
+ assert m .agent_model_result_tools == snapshot (
257
256
{
258
- 'name' : 'final_result' ,
259
- 'description' : 'The final response which ends this conversation' ,
260
- 'parameters_json_schema' : {
261
- 'properties' : {'a' : {'title' : 'A' , 'type' : 'integer' }, 'b' : {'title' : 'B' , 'type' : 'string' }},
262
- 'title' : 'Foo' ,
263
- 'required' : ['a' , 'b' ],
264
- 'type' : 'object' ,
265
- },
266
- 'outer_typed_dict_key' : None ,
257
+ 'final_result' : ToolDefinition (
258
+ name = 'final_result' ,
259
+ description = 'The final response which ends this conversation' ,
260
+ parameters_json_schema = {
261
+ 'properties' : {
262
+ 'a' : {'title' : 'A' , 'type' : 'integer' },
263
+ 'b' : {'title' : 'B' , 'type' : 'string' },
264
+ },
265
+ 'required' : ['a' , 'b' ],
266
+ 'title' : 'Foo' ,
267
+ 'type' : 'object' ,
268
+ },
269
+ )
267
270
}
268
271
)
269
272
@@ -324,15 +327,12 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
324
327
assert m .agent_model_result_tools is not None
325
328
assert len (m .agent_model_result_tools ) == 2
326
329
327
- # to match the protocol, we just extract the attributes we care about
328
- fields = 'name' , 'description' , 'parameters_json_schema' , 'outer_typed_dict_key'
329
- agent_model_result_tools = [{f : getattr (t , f ) for f in fields } for t in m .agent_model_result_tools ]
330
- assert agent_model_result_tools == snapshot (
331
- [
332
- {
333
- 'name' : 'final_result_Foo' ,
334
- 'description' : 'Foo: The final response which ends this conversation' ,
335
- 'parameters_json_schema' : {
330
+ assert m .agent_model_result_tools == snapshot (
331
+ {
332
+ 'final_result_Foo' : ToolDefinition (
333
+ name = 'final_result_Foo' ,
334
+ description = 'Foo: The final response which ends this conversation' ,
335
+ parameters_json_schema = {
336
336
'properties' : {
337
337
'a' : {'title' : 'A' , 'type' : 'integer' },
338
338
'b' : {'title' : 'B' , 'type' : 'string' },
@@ -341,20 +341,18 @@ def validate_result(ctx: RunContext[None], r: Any) -> Any:
341
341
'title' : 'Foo' ,
342
342
'type' : 'object' ,
343
343
},
344
- 'outer_typed_dict_key' : None ,
345
- },
346
- {
347
- 'name' : 'final_result_Bar' ,
348
- 'description' : 'This is a bar model.' ,
349
- 'parameters_json_schema' : {
344
+ ),
345
+ 'final_result_Bar' : ToolDefinition (
346
+ name = 'final_result_Bar' ,
347
+ description = 'This is a bar model.' ,
348
+ parameters_json_schema = {
350
349
'properties' : {'b' : {'title' : 'B' , 'type' : 'string' }},
351
350
'required' : ['b' ],
352
351
'title' : 'Bar' ,
353
352
'type' : 'object' ,
354
353
},
355
- 'outer_typed_dict_key' : None ,
356
- },
357
- ]
354
+ ),
355
+ }
358
356
)
359
357
360
358
result = agent .run_sync ('Hello' , model = TestModel (seed = 1 ))
@@ -383,6 +381,7 @@ async def ret_a(x: str) -> str:
383
381
ModelTextResponse (content = '{"ret_a":"a-apple"}' , timestamp = IsNow (tz = timezone .utc )),
384
382
]
385
383
)
384
+ m .agent_model = None
386
385
387
386
# if we pass new_messages, system prompt is inserted before the message_history messages
388
387
result2 = agent .run_sync ('Hello again' , message_history = result1 .new_messages ())
@@ -415,6 +414,8 @@ async def ret_a(x: str) -> str:
415
414
assert new_msg_roles == snapshot (['user' , 'model-structured-response' , 'tool-return' , 'model-text-response' ])
416
415
assert result2 .new_messages_json ().startswith (b'[{"content":"Hello again",' )
417
416
417
+ m .agent_model = None
418
+
418
419
# if we pass all_messages, system prompt is NOT inserted before the message_history messages,
419
420
# so only one system prompt
420
421
result3 = agent .run_sync ('Hello again' , message_history = result1 .all_messages ())
@@ -540,6 +541,7 @@ async def make_request() -> str:
540
541
for _ in range (2 ):
541
542
result = agent .run_sync ('Hello' )
542
543
assert result .data == '{"make_request":"200"}'
544
+ agent .model .agent_model = None
543
545
544
546
545
547
async def test_agent_name ():
0 commit comments