112 changes: 59 additions & 53 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -547,7 +547,7 @@ async def _run_stream( # noqa: C901
async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa: C901
text = ''
tool_calls: list[_messages.ToolCallPart] = []
thinking_parts: list[_messages.ThinkingPart] = []
invisible_parts: bool = False

for part in self.model_response.parts:
if isinstance(part, _messages.TextPart):
@@ -558,55 +558,65 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # noqa
# Text parts before a built-in tool call are essentially thoughts,
# not part of the final result output, so we reset the accumulated text
text = ''
invisible_parts = True
yield _messages.BuiltinToolCallEvent(part) # pyright: ignore[reportDeprecated]
elif isinstance(part, _messages.BuiltinToolReturnPart):
invisible_parts = True
yield _messages.BuiltinToolResultEvent(part) # pyright: ignore[reportDeprecated]
elif isinstance(part, _messages.ThinkingPart):
thinking_parts.append(part)
invisible_parts = True
else:
assert_never(part)

# At the moment, we prioritize at least executing tool calls if they are present.
# In the future, we'd consider making this configurable at the agent or run level.
# This accounts for cases like Anthropic responses that contain both a text part
# and a tool call, where the text just indicates that the tool call will happen.
if tool_calls:
async for event in self._handle_tool_calls(ctx, tool_calls):
yield event
elif text:
# No events are emitted during the handling of text responses, so we don't need to yield anything
self._next_node = await self._handle_text_response(ctx, text)
elif thinking_parts:
# handle thinking-only responses (responses that contain only ThinkingPart instances)
# this can happen with models that support thinking mode when they don't provide
# actionable output alongside their thinking content.
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
_messages.ModelRequest(
parts=[_messages.RetryPromptPart('Responses without text or tool calls are not permitted.')]
try:
if tool_calls:
async for event in self._handle_tool_calls(ctx, tool_calls):
yield event
elif text:
# No events are emitted during the handling of text responses, so we don't need to yield anything
self._next_node = await self._handle_text_response(ctx, text)
elif invisible_parts:
# handle responses with only thinking or built-in tool parts.
# this can happen with models that support thinking mode when they don't provide
# actionable output alongside their thinking content, so we tell the model to try again.
m = _messages.RetryPromptPart(
content='Responses without text or tool calls are not permitted.',
)
)
else:
# we got an empty response with no tool calls, text, or thinking
# this sometimes happens with anthropic (and perhaps other models)
# when the model has already returned text alongside tool calls.
# In this scenario, if text responses are allowed, we return text from the most recent model
# response, if any.
if isinstance(ctx.deps.output_schema, _output.TextOutputSchema):
for message in reversed(ctx.state.message_history):
if isinstance(message, _messages.ModelResponse):
text = ''
for part in message.parts:
if isinstance(part, _messages.TextPart):
text += part.content
elif isinstance(part, _messages.BuiltinToolCallPart):
# Text parts before a built-in tool call are essentially thoughts,
# not part of the final result output, so we reset the accumulated text
text = '' # pragma: no cover
if text:
self._next_node = await self._handle_text_response(ctx, text)
return

raise exceptions.UnexpectedModelBehavior('Received empty model response')
raise ToolRetryError(m)
else:
# we got an empty response with no tool calls, text, thinking, or built-in tool calls.
# this sometimes happens with anthropic (and perhaps other models)
# when the model has already returned text alongside tool calls.
# In this scenario, if text responses are allowed, we return text from the most recent model
# response, if any.
if isinstance(ctx.deps.output_schema, _output.TextOutputSchema):
for message in reversed(ctx.state.message_history):
if isinstance(message, _messages.ModelResponse):
text = ''
for part in message.parts:
if isinstance(part, _messages.TextPart):
text += part.content
elif isinstance(part, _messages.BuiltinToolCallPart):
# Text parts before a built-in tool call are essentially thoughts,
# not part of the final result output, so we reset the accumulated text
text = '' # pragma: no cover
if text:
self._next_node = await self._handle_text_response(ctx, text)
return

# Go back to the model request node with an empty request, which means we'll essentially
# resubmit the most recent request that resulted in an empty response,
# as the empty response and request will not create any items in the API payload,
# in the hope the model will return a non-empty response this time.
ctx.state.increment_retries(ctx.deps.max_result_retries)
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[]))
except ToolRetryError as e:
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))

self._events_iterator = _run_stream()
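
Taken together, the hunk above establishes a strict priority order for handling a model response: execute tool calls first, then treat text as output, then fall back to a retry prompt for thinking-only or built-in-tool-only responses, and finally resubmit an empty request for fully empty responses. The sketch below restates that order as a small, self-contained program — toy dataclasses stand in for pydantic_ai's message parts; this illustrates the branch logic, not the library's actual code:

```python
from dataclasses import dataclass


@dataclass
class TextPart:
    content: str


@dataclass
class ToolCallPart:
    tool_name: str


@dataclass
class ThinkingPart:  # stands in for thinking and built-in tool parts alike
    content: str


def decide(parts: list) -> str:
    """Mirror the new branch order: tool calls, then text, then invisible parts, then empty."""
    text = ''
    tool_calls: list[ToolCallPart] = []
    invisible_parts = False
    for part in parts:
        if isinstance(part, TextPart):
            text += part.content
        elif isinstance(part, ToolCallPart):
            tool_calls.append(part)
        else:
            invisible_parts = True
    if tool_calls:
        return 'handle tool calls'  # always prefer executing tool calls
    if text:
        return 'handle text output'
    if invisible_parts:
        # thinking-only / built-in-tool-only response: ask the model to retry
        return 'retry: responses without text or tool calls are not permitted'
    # fully empty response: resubmit an empty request and hope for content
    return 'resubmit empty request'


assert decide([ThinkingPart('hmm')]).startswith('retry')
assert decide([]) == 'resubmit empty request'
assert decide([TextPart('hi'), ToolCallPart('f')]) == 'handle tool calls'
```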

@@ -666,23 +676,19 @@ async def _handle_text_response(
text: str,
) -> ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]]:
output_schema = ctx.deps.output_schema
try:
run_context = build_run_context(ctx)
if isinstance(output_schema, _output.TextOutputSchema):
result_data = await output_schema.process(text, run_context)
else:
m = _messages.RetryPromptPart(
content='Plain text responses are not permitted, please include your response in a tool call',
)
raise ToolRetryError(m)
run_context = build_run_context(ctx)

for validator in ctx.deps.output_validators:
result_data = await validator.validate(result_data, run_context)
except ToolRetryError as e:
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
if isinstance(output_schema, _output.TextOutputSchema):
result_data = await output_schema.process(text, run_context)
else:
return self._handle_final_result(ctx, result.FinalResult(result_data), [])
m = _messages.RetryPromptPart(
content='Plain text responses are not permitted, please include your response in a tool call',
)
raise ToolRetryError(m)

for validator in ctx.deps.output_validators:
result_data = await validator.validate(result_data, run_context)
return self._handle_final_result(ctx, result.FinalResult(result_data), [])

__repr__ = dataclasses_no_defaults_repr
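
With this change, `_handle_text_response` no longer converts a `ToolRetryError` into a retry request itself; it simply raises, and the single `except ToolRetryError` block added to `_run_stream` above builds the retry `ModelRequest` for every failure path. A minimal sketch of that centralize-the-handler pattern, using toy stand-ins rather than the library's types:

```python
class ToolRetryError(Exception):
    """Toy stand-in for pydantic_ai's internal ToolRetryError."""

    def __init__(self, retry_prompt: str) -> None:
        super().__init__(retry_prompt)
        self.retry_prompt = retry_prompt


def handle_text(text: str, text_output_allowed: bool) -> str:
    # Mirrors the simplified _handle_text_response: raise instead of
    # building the retry request locally.
    if not text_output_allowed:
        raise ToolRetryError(
            'Plain text responses are not permitted, please include your response in a tool call'
        )
    return text


def run_stream_step(text: str, text_output_allowed: bool) -> str:
    # Mirrors the new except block in _run_stream: one place turns any
    # ToolRetryError into the next (retry) request.
    try:
        return f'final result: {handle_text(text, text_output_allowed)}'
    except ToolRetryError as e:
        return f'retry request: {e.retry_prompt}'


assert run_stream_step('hi', True) == 'final result: hi'
assert run_stream_step('hi', False).startswith('retry request')
```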

13 changes: 1 addition & 12 deletions tests/models/test_groq.py
@@ -15,7 +15,7 @@
from pydantic import BaseModel
from typing_extensions import TypedDict

from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior
from pydantic_ai import Agent, ModelHTTPError, ModelRetry
from pydantic_ai.builtin_tools import WebSearchTool
from pydantic_ai.messages import (
BinaryContent,
@@ -533,17 +533,6 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
assert result.is_complete


async def test_no_content(allow_model_requests: None):
stream = chunk([ChoiceDelta()]), chunk([ChoiceDelta()])
mock_client = MockGroq.create_mock_stream(stream)
m = GroqModel('llama-3.3-70b-versatile', provider=GroqProvider(groq_client=mock_client))
agent = Agent(m, output_type=MyTypedDict)

with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
async with agent.run_stream(''):
pass


async def test_no_delta(allow_model_requests: None):
stream = chunk([]), text_chunk('hello '), text_chunk('world')
mock_client = MockGroq.create_mock_stream(stream)
16 changes: 1 addition & 15 deletions tests/models/test_huggingface.py
@@ -13,7 +13,7 @@
from inline_snapshot import snapshot
from typing_extensions import TypedDict

from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior
from pydantic_ai import Agent, ModelRetry
from pydantic_ai.exceptions import ModelHTTPError
from pydantic_ai.messages import (
AudioUrl,
@@ -601,20 +601,6 @@ async def test_stream_structured_finish_reason(allow_model_requests: None):
assert result.is_complete


async def test_no_content(allow_model_requests: None):
stream = [
chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore
chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore
]
mock_client = MockHuggingFace.create_stream_mock(stream)
m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x'))
agent = Agent(m, output_type=MyTypedDict)

with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
async with agent.run_stream(''):
pass


async def test_no_delta(allow_model_requests: None):
stream = [
chunk([]),
11 changes: 0 additions & 11 deletions tests/models/test_openai.py
@@ -606,17 +606,6 @@ async def test_stream_text_empty_think_tag_and_text_before_tool_call(allow_model
assert await result.get_output() == snapshot({'first': 'One', 'second': 'Two'})


async def test_no_content(allow_model_requests: None):
stream = [chunk([ChoiceDelta()]), chunk([ChoiceDelta()])]
mock_client = MockOpenAI.create_mock_stream(stream)
m = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
agent = Agent(m, output_type=MyTypedDict)

with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
async with agent.run_stream(''):
pass


async def test_no_delta(allow_model_requests: None):
stream = [
chunk([]),
75 changes: 70 additions & 5 deletions tests/test_agent.py
@@ -2168,14 +2168,79 @@ def simple_response(_messages: list[ModelMessage], _info: AgentInfo) -> ModelRes
assert result.new_messages() == []


def test_empty_tool_calls():
def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
def test_empty_response():
def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
if len(messages) == 1:
return ModelResponse(parts=[])
else:
return ModelResponse(parts=[TextPart('ok here is text')])

agent = Agent(FunctionModel(llm))

result = agent.run_sync('Hello')

assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
UserPromptPart(
content='Hello',
timestamp=IsDatetime(),
)
]
),
ModelResponse(
parts=[],
usage=RequestUsage(input_tokens=51),
model_name='function:llm:',
timestamp=IsDatetime(),
),
ModelRequest(parts=[]),
ModelResponse(
parts=[TextPart(content='ok here is text')],
usage=RequestUsage(input_tokens=51, output_tokens=4),
model_name='function:llm:',
timestamp=IsDatetime(),
),
]
)


def test_empty_response_without_recovery():
def llm(messages: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
return ModelResponse(parts=[])

agent = Agent(FunctionModel(empty))
agent = Agent(FunctionModel(llm), output_type=tuple[str, int])

with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
agent.run_sync('Hello')
with capture_run_messages() as messages:
with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'):
agent.run_sync('Hello')

assert messages == snapshot(
[
ModelRequest(
parts=[
UserPromptPart(
content='Hello',
timestamp=IsDatetime(),
)
]
),
ModelResponse(
parts=[],
usage=RequestUsage(input_tokens=51),
model_name='function:llm:',
timestamp=IsDatetime(),
),
ModelRequest(parts=[]),
ModelResponse(
parts=[],
usage=RequestUsage(input_tokens=51),
model_name='function:llm:',
timestamp=IsDatetime(),
),
]
)
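
The error this test expects comes from `ctx.state.increment_retries(ctx.deps.max_result_retries, e)` in the `_agent_graph.py` hunk: each empty response consumes one retry, and exhausting the budget aborts the run. A hypothetical sketch of that bookkeeping (the real `increment_retries` is internal to pydantic-ai and may differ in detail):

```python
class UnexpectedModelBehavior(Exception):
    """Toy stand-in for pydantic_ai.UnexpectedModelBehavior."""


class RunState:
    def __init__(self) -> None:
        self.retries = 0

    def increment_retries(self, max_retries: int, error: Exception | None = None) -> None:
        # Hypothetical: count one retry and abort once the budget is exhausted,
        # producing the message the test above matches against.
        self.retries += 1
        if self.retries > max_retries:
            raise UnexpectedModelBehavior(
                f'Exceeded maximum retries ({max_retries}) for output validation'
            ) from error


state = RunState()
state.increment_retries(1)  # first empty response: retry permitted
try:
    state.increment_retries(1)  # second empty response: budget exhausted
except UnexpectedModelBehavior as e:
    assert 'Exceeded maximum retries (1)' in str(e)
```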


def test_unknown_tool():
46 changes: 39 additions & 7 deletions tests/test_streaming.py
Expand Up @@ -400,15 +400,47 @@ async def ret_a(x: str) -> str:
)


async def test_call_tool_empty():
async def stream_structured_function(_messages: list[ModelMessage], _: AgentInfo) -> AsyncIterator[DeltaToolCalls]:
yield {}
async def test_empty_response():
async def stream_structured_function(
messages: list[ModelMessage], _: AgentInfo
) -> AsyncIterator[DeltaToolCalls | str]:
if len(messages) == 1:
yield {}
else:
yield 'ok here is text'

agent = Agent(FunctionModel(stream_function=stream_structured_function), output_type=tuple[str, int])
agent = Agent(FunctionModel(stream_function=stream_structured_function))

with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'):
async with agent.run_stream('hello'):
pass
async with agent.run_stream('hello') as result:
response = await result.get_output()
assert response == snapshot('ok here is text')
messages = result.all_messages()

assert messages == snapshot(
[
ModelRequest(
parts=[
UserPromptPart(
content='hello',
timestamp=IsDatetime(),
)
]
),
ModelResponse(
parts=[],
usage=RequestUsage(input_tokens=50),
model_name='function::stream_structured_function',
timestamp=IsDatetime(),
),
ModelRequest(parts=[]),
ModelResponse(
parts=[TextPart(content='ok here is text')],
usage=RequestUsage(input_tokens=50, output_tokens=4),
model_name='function::stream_structured_function',
timestamp=IsDatetime(),
),
]
)


async def test_call_tool_wrong_name():
Expand Down