|
4 | 4 |
|
5 | 5 | import pytest
|
6 | 6 | from inline_snapshot import snapshot
|
7 |
| -from pydantic import BaseModel |
| 7 | +from pydantic import BaseModel, field_validator |
8 | 8 |
|
9 | 9 | from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError
|
10 | 10 | from pydantic_ai.messages import (
|
@@ -105,6 +105,50 @@ def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
|
105 | 105 | assert result.all_messages_json().startswith(b'[{"content":"Hello"')
|
106 | 106 |
|
107 | 107 |
|
| 108 | +def test_result_pydantic_model_validation_error(set_event_loop: None): |
| 109 | + def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse: |
| 110 | + assert info.result_tools is not None |
| 111 | + if len(messages) == 1: |
| 112 | + args_json = '{"a": 1, "b": "foo"}' |
| 113 | + else: |
| 114 | + args_json = '{"a": 1, "b": "bar"}' |
| 115 | + return ModelStructuredResponse(calls=[ToolCall.from_json(info.result_tools[0].name, args_json)]) |
| 116 | + |
| 117 | + class Bar(BaseModel): |
| 118 | + a: int |
| 119 | + b: str |
| 120 | + |
| 121 | + @field_validator('b') |
| 122 | + def check_b(cls, v: str) -> str: |
| 123 | + if v == 'foo': |
| 124 | + raise ValueError('must not be foo') |
| 125 | + return v |
| 126 | + |
| 127 | + agent = Agent(FunctionModel(return_model), result_type=Bar) |
| 128 | + |
| 129 | + result = agent.run_sync('Hello') |
| 130 | + assert isinstance(result.data, Bar) |
| 131 | + assert result.data.model_dump() == snapshot({'a': 1, 'b': 'bar'}) |
| 132 | + message_roles = [m.role for m in result.all_messages()] |
| 133 | + assert message_roles == snapshot(['user', 'model-structured-response', 'retry-prompt', 'model-structured-response']) |
| 134 | + |
| 135 | + retry_prompt = result.all_messages()[2] |
| 136 | + assert isinstance(retry_prompt, RetryPrompt) |
| 137 | + assert retry_prompt.model_response() == snapshot("""\ |
| 138 | +1 validation errors: [ |
| 139 | + { |
| 140 | + "type": "value_error", |
| 141 | + "loc": [ |
| 142 | + "b" |
| 143 | + ], |
| 144 | + "msg": "Value error, must not be foo", |
| 145 | + "input": "foo" |
| 146 | + } |
| 147 | +] |
| 148 | +
|
| 149 | +Fix the errors and try again.""") |
| 150 | + |
| 151 | + |
108 | 152 | def test_result_validator(set_event_loop: None):
|
109 | 153 | def return_model(messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
|
110 | 154 | assert info.result_tools is not None
|
|
0 commit comments